"""Validate LLM API keys against OpenAI, Anthropic, and Google (module ``academic_metrics.utils.api_key_validator``)."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

if TYPE_CHECKING:
    from langchain.schema.runnable import Runnable


@dataclass
class ValidationResult:
    """Per-service outcome of validating a single API key.

    Every flag defaults to ``False``; the validator sets a flag to ``True``
    only after a test request to that service succeeded with the key.
    """

    openai: bool = False  # key accepted by OpenAI
    anthropic: bool = False  # key accepted by Anthropic
    google: bool = False  # key accepted by Google Generative AI
class APIKeyValidator:
    """Validator for LLM API keys across different services.

    A key is probed against OpenAI, Anthropic, and Google (through their
    LangChain chat-model integrations) by sending a minimal "test" prompt.
    Results are cached per API key on this instance, so each key is only
    probed once.

    Example:
        >>> validator = APIKeyValidator()
        >>> if validator.is_valid(api_key="sk-..."):
        ...     print("Key is valid!")
        >>> validator.print_results_for_api_key("sk-...")  # See which services work
    """

    # Model probed for each service when the caller does not pass ``model``.
    # Anthropic model IDs use hyphens ("claude-3-5-haiku"), not dots.
    _DEFAULT_MODELS: Dict[str, str] = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-3-5-haiku-latest",
        "google": "gemini-1.5-pro",
    }

    def __init__(self):
        """Initialize the APIKeyValidator with an empty results cache."""
        # Cache of api keys that have been validated already.
        # Key = api_key, Value = per-service ValidationResult.
        self._validated_already: Dict[str, "ValidationResult"] = {}

    @staticmethod
    def _build_llm(service: str, api_key: str, model: str):
        """Import and construct the LangChain chat model for ``service``.

        Imports are done lazily so a missing integration package only makes
        that one service's probe fail.

        Raises:
            ValueError: If ``service`` is not one of the supported services.
        """
        if service == "openai":
            from langchain_openai import ChatOpenAI

            return ChatOpenAI(api_key=api_key, model=model)
        if service == "anthropic":
            from langchain_anthropic import ChatAnthropic

            return ChatAnthropic(api_key=api_key, model=model)
        if service == "google":
            from langchain_google_genai import ChatGoogleGenerativeAI

            return ChatGoogleGenerativeAI(api_key=api_key, model=model)
        raise ValueError(f"Unknown service: {service!r}")

    def _probe(self, prompt, service: str, api_key: str, model: Optional[str]) -> bool:
        """Return True if one request to ``service`` using ``api_key`` succeeds.

        Args:
            prompt: The LangChain chat prompt to send.
            service: One of the keys of ``_DEFAULT_MODELS``.
            api_key: The key under test.
            model: Model override; falls back to the service's default.
        """
        import logging

        try:
            llm = self._build_llm(
                service, api_key, model or self._DEFAULT_MODELS[service]
            )
            (prompt | llm).invoke({})
        except Exception as exc:
            # Deliberately best-effort: any failure (missing package, auth
            # error, unknown model, network error) means "not valid here".
            # Only the exception type is logged so the key never reaches logs.
            logging.getLogger(__name__).debug(
                "API key check failed for %s: %s", service, type(exc).__name__
            )
            return False
        return True

    def _validate(self, api_key: str, model: Optional[str] = None) -> None:
        """Probe ``api_key`` against every supported service and cache the result.

        Args:
            api_key: The key to test.
            model: Optional model name used for *every* service; when ``None``
                each service uses its entry in ``_DEFAULT_MODELS``.
        """
        from langchain.prompts import (
            ChatPromptTemplate,
            HumanMessagePromptTemplate,
            SystemMessagePromptTemplate,
        )

        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template("test"),
                HumanMessagePromptTemplate.from_template("test"),
            ]
        )
        self._validated_already[api_key] = ValidationResult(
            openai=self._probe(prompt, "openai", api_key, model),
            anthropic=self._probe(prompt, "anthropic", api_key, model),
            google=self._probe(prompt, "google", api_key, model),
        )

    def _check_attr(self) -> None:
        """Ensure at least one API key has been validated on this instance.

        Raises:
            RuntimeError: If no key has been validated yet.
        """
        if not self._validated_already:
            raise RuntimeError(
                "Must call is_valid() before checking validity. "
                "Example usage:\n"
                ">>> validator = APIKeyValidator()\n"
                ">>> if validator.is_valid(api_key='...'):\n"
                ">>>     print('Key is valid!')\n"
                ">>> else:\n"
                ">>>     print('Key is invalid!')"
            )

    def is_valid(self, api_key: str, model: Optional[str] = None) -> bool:
        """Return True if ``api_key`` works for at least one service.

        Validates the key on first use. The cache is keyed by ``api_key``
        only, so ``model`` has no effect once a key has been validated.
        """
        if api_key not in self._validated_already:
            self._validate(api_key=api_key, model=model)
        return any(self.get_results_for_api_key(api_key).values())

    def get_results_for_api_key(self, api_key: str) -> Dict[str, bool]:
        """Get detailed validation results. Validates if not already done.

        Returns:
            ``{"openai": bool, "anthropic": bool, "google": bool}``.
        """
        if api_key not in self._validated_already:
            self._validate(api_key=api_key)
        results = self._validated_already[api_key]
        return {
            "openai": results.openai,
            "anthropic": results.anthropic,
            "google": results.google,
        }

    def get_full_results(self) -> Dict[str, Dict[str, bool]]:
        """Get detailed validation results for all keys validated so far.

        Returns:
            Mapping of api_key to its per-service result dict (same shape as
            ``get_results_for_api_key``).
        """
        return {
            api_key: self.get_results_for_api_key(api_key)
            for api_key in self._validated_already
        }

    @staticmethod
    def _mask_key(api_key: str, visible: int = 4) -> str:
        """Return ``api_key`` with all but the first/last ``visible`` chars hidden.

        Keys too short to show both ends are fully replaced by ``*``.
        """
        if len(api_key) <= 2 * visible:
            return "*" * len(api_key)
        return f"{api_key[:visible]}...{api_key[-visible:]}"

    def print_results_for_api_key(self, api_key: str) -> None:
        """Print formatted validation results for a given API key.

        The key itself is masked so the secret is not written to stdout.
        """
        from academic_metrics.utils.unicode_chars_dict import unicode_chars_dict

        results: Dict[str, bool] = self.get_results_for_api_key(api_key)
        valid_mark = unicode_chars_dict.get("boxed_checkmark", "")
        invalid_mark = unicode_chars_dict.get("boxed_x", "")
        print(f"API Key: {self._mask_key(api_key)}")
        for service, valid in results.items():
            status = f"{valid_mark} Valid" if valid else f"{invalid_mark} Invalid"
            print(f"{service.title()}: {status}")
        print("-" * 25)

    def print_full_results(self) -> None:
        """Print formatted validation results for all keys."""
        print("\nAPI Key Validation Results:")
        print("-" * 25)
        for api_key in list(self._validated_already):
            self.print_results_for_api_key(api_key)