|
|
import asyncio
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from dotenv import load_dotenv
|
|
|
|
|
|
|
|
# Logger for this module's messages.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO so the diagnostics below are emitted
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
# The .env file is expected one directory above the folder containing this
# module; resolve it from this file's absolute location, not from the CWD.
_module_dir = os.path.dirname(os.path.abspath(__file__))
_parent_dir = os.path.dirname(_module_dir)
env_path = os.path.join(_parent_dir, '.env')

load_dotenv(dotenv_path=env_path)
|
|
|
|
|
|
|
|
def _mask_secret(value: Optional[str]) -> str:
    """Return a log-safe rendering of a secret environment value.

    Unset or empty values render as ``'Not set'``. Values longer than eight
    characters render as eight asterisks followed by their last four
    characters. Shorter values are fully masked, because showing four
    characters of them would reveal most (or all) of the secret.
    """
    if not value:
        return 'Not set'
    if len(value) <= 8:
        return '*' * 8
    return '*' * 8 + value[-4:]


# Startup diagnostics to help debug where configuration is being read from.
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Loading .env from: {env_path}")
logger.info(f"GEMINI_API_KEY: {_mask_secret(os.getenv('GEMINI_API_KEY'))}")
logger.info(f"GEMINI_API_KEYS: {_mask_secret(os.getenv('GEMINI_API_KEYS'))}")
|
|
|
|
|
@dataclass
class APIKey:
    """State tracked for a single API key.

    The secret itself is excluded from ``repr()``, so printing or logging an
    ``APIKey`` (or a list of them) never writes the key in clear text.
    Equality comparison still includes the key.
    """

    # The API key string; hidden from repr() but still used for equality.
    key: str = field(repr=False)
    # When the key was last handed out (naive UTC), or None if never used.
    last_used: Optional[datetime] = None
    # False while the key is marked as rate-limited.
    is_available: bool = True
    # Time (naive UTC) after which an unavailable key may be used again.
    rate_limit_reset: Optional[datetime] = None
|
|
|
|
|
class APIKeyManager:
    """Process-wide registry of Gemini API keys with rate-limit tracking.

    This is a singleton: every ``APIKeyManager()`` call returns the same
    instance, and keys are read from the environment only on the first call.
    Keys come from ``GEMINI_API_KEY`` (a single key). If that is unset or
    empty, they come from ``GEMINI_API_KEYS`` (a comma-separated list).
    """

    _instance = None

    # NOTE(review): no method in this class acquires this lock, and it is
    # created at import time; confirm whether any other code relies on it.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Create and initialise the shared instance on first use only.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Set up instance state and load keys from the environment."""
        self.keys: List[APIKey] = []
        # NOTE(review): not read by any method here (get_available_key always
        # scans from the first key); kept in case external code references it.
        self._current_index = 0
        self._load_api_keys()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Produces the same value as the deprecated ``datetime.utcnow()``, so
        stored timestamps keep their existing naive-UTC form.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _clean_key(raw: str) -> str:
        """Strip surrounding whitespace and quote characters from one key."""
        return raw.strip().strip('"\'').strip()

    @classmethod
    def _parse_key_list(cls, raw: str) -> List[str]:
        """Split a comma-separated key string into unique, non-empty keys.

        Entries that are empty after cleanup are dropped. Duplicates are
        removed while keeping first-seen order.
        """
        cleaned = (cls._clean_key(part) for part in raw.split(','))
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(key for key in cleaned if key))

    def _load_api_keys(self):
        """Populate ``self.keys`` from the environment.

        ``GEMINI_API_KEY`` takes precedence. ``GEMINI_API_KEYS`` is used only
        when the single-key variable is unset or empty after stripping quotes.
        Duplicate keys are loaded once, so rate-limiting a key cannot leave a
        second copy of it available.
        """
        single_key = self._clean_key(os.getenv('GEMINI_API_KEY', ''))
        if single_key:
            self.keys = [APIKey(key=single_key)]
            logger.info("Loaded 1 API key from GEMINI_API_KEY")
            return

        keys = self._parse_key_list(os.getenv('GEMINI_API_KEYS', ''))
        if keys:
            self.keys = [APIKey(key=key) for key in keys]
            logger.info(f"Loaded {len(keys)} API keys from GEMINI_API_KEYS")
            return

        logger.warning("No API keys found in environment variables")

    def get_available_key(self) -> Optional[str]:
        """Get an available API key, considering rate limits.

        Keys are scanned in load order and the first usable one is returned,
        with its ``last_used`` set to the current naive-UTC time. An
        unavailable key becomes usable again once its ``rate_limit_reset``
        time has passed. Returns ``None`` when no key is usable.
        """
        now = self._utcnow()
        for key_obj in self.keys:
            if not key_obj.is_available:
                if key_obj.rate_limit_reset is None or now < key_obj.rate_limit_reset:
                    continue
                # The wait period has expired: put the key back into service.
                key_obj.is_available = True
                key_obj.rate_limit_reset = None
            key_obj.last_used = now
            return key_obj.key
        return None

    def mark_key_unavailable(self, key: str, retry_after_seconds: int = 60):
        """Mark a key as unavailable due to rate limiting.

        Every loaded entry equal to ``key`` is marked unavailable until
        ``retry_after_seconds`` from now. An unknown key is logged and
        otherwise ignored.
        """
        matches = [key_obj for key_obj in self.keys if key_obj.key == key]
        if not matches:
            logger.warning("Tried to mark unknown API key as unavailable")
            return
        reset_at = self._utcnow() + timedelta(seconds=retry_after_seconds)
        for key_obj in matches:
            key_obj.is_available = False
            key_obj.rate_limit_reset = reset_at
        logger.warning(f"Rate limit hit for API key. Will retry after {retry_after_seconds} seconds")
|
|
|
|
|
# Module-level handle to the shared manager. APIKeyManager is a singleton,
# so this is the same object that any later APIKeyManager() call returns.
api_key_manager: APIKeyManager = APIKeyManager()