# AEGIS-SECURE-API / api_keys.py
# Akshat Bhatt
# added code
# e2e0c18
import asyncio
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional

from dotenv import load_dotenv
# Configure root logging at INFO level and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read environment variables from the .env file located in the parent of the
# directory that contains this module (one level above this file's folder).
env_path = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, '.env')
)
load_dotenv(dotenv_path=env_path)
# Minimum length a secret must have before its last four characters are shown
# in logs; shorter values are fully masked so most of the secret never leaks.
_SECRET_REVEAL_MIN_LEN = 16


def _mask_secret(raw: Optional[str]) -> str:
    """Return a log-safe rendering of a single secret value.

    Surrounding whitespace and quote characters are stripped first. Unset or
    empty values render as 'Not set'. Otherwise eight asterisks are shown,
    followed by the last four characters only when the value is at least
    ``_SECRET_REVEAL_MIN_LEN`` characters long.
    """
    value = (raw or '').strip().strip('"\'').strip()
    if not value:
        return 'Not set'
    if len(value) < _SECRET_REVEAL_MIN_LEN:
        return '*' * 8
    return '*' * 8 + value[-4:]


def _mask_secret_list(raw: Optional[str]) -> str:
    """Return a log-safe rendering of a comma-separated list of secrets.

    Each non-empty entry is masked individually. The output therefore reflects
    the real keys, not the tail of the raw string (which may be a quote or a
    stray comma). Renders 'Not set' when no non-empty entry exists.
    """
    masked = [
        _mask_secret(part)
        for part in (raw or '').split(',')
        if part.strip().strip('"\'').strip()
    ]
    return ', '.join(masked) if masked else 'Not set'


# Startup diagnostics: where configuration is read from, and which keys are
# configured (masked so the secrets themselves never reach the logs).
logger.info("Current working directory: %s", os.getcwd())
logger.info("Loading .env from: %s", env_path)
logger.info("GEMINI_API_KEY: %s", _mask_secret(os.getenv('GEMINI_API_KEY')))
logger.info("GEMINI_API_KEYS: %s", _mask_secret_list(os.getenv('GEMINI_API_KEYS')))
@dataclass
class APIKey:
    """Bookkeeping for a single API key.

    The key string is excluded from the generated ``__repr__`` so that
    printing or logging an ``APIKey`` (or a list of them) cannot leak the
    secret. It still participates in equality comparison.
    """

    # The raw API key string.
    key: str = field(repr=False)
    # Naive-UTC timestamp of when the key was last handed out, or None.
    last_used: Optional[datetime] = None
    # False while the key is considered rate-limited.
    is_available: bool = True
    # Naive-UTC time after which a rate-limited key may be reused, or None.
    rate_limit_reset: Optional[datetime] = None
class APIKeyManager:
    """Singleton that hands out Gemini API keys and tracks their rate limits.

    Keys are read from the environment when the shared instance is first
    created. ``GEMINI_API_KEY`` (a single key) takes precedence; otherwise
    ``GEMINI_API_KEYS`` is parsed as a comma-separated list.
    """

    _instance = None
    # NOTE(review): this lock is never acquired inside this class. On
    # Python < 3.10, an asyncio.Lock created at class-definition time binds to
    # the loop returned by get_event_loop() at import, which may not be the
    # loop that later awaits it. Confirm whether other code relies on it.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Build and initialise the shared instance on first construction;
        # every later APIKeyManager() call returns that same object.
        if cls._instance is None:
            cls._instance = super(APIKeyManager, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Set up per-instance state and load keys from the environment."""
        self.keys: List[APIKey] = []
        # NOTE(review): not read anywhere in this class. get_available_key()
        # always scans from the start of the list (first available key wins).
        self._current_index = 0
        self._load_api_keys()

    @staticmethod
    def _clean_keys(raw_keys: List[str]) -> List[str]:
        """Normalise raw key strings.

        Strips whitespace and surrounding quotes, drops empty values, and
        removes duplicates while preserving first-seen order. Keys are
        identified by value elsewhere in this class, so each distinct key
        must map to exactly one APIKey entry.
        """
        cleaned = (raw.strip().strip('"\'').strip() for raw in raw_keys)
        return list(dict.fromkeys(key for key in cleaned if key))

    def _load_api_keys(self):
        """Populate ``self.keys`` from GEMINI_API_KEY, else GEMINI_API_KEYS."""
        # A single GEMINI_API_KEY takes precedence over the multi-key list.
        single = self._clean_keys([os.getenv('GEMINI_API_KEY', '')])
        if single:
            self.keys = [APIKey(key=single[0])]
            logger.info("Loaded 1 API key from GEMINI_API_KEY")
            return

        # Fall back to the comma-separated GEMINI_API_KEYS list.
        keys = self._clean_keys(os.getenv('GEMINI_API_KEYS', '').split(','))
        if keys:
            self.keys = [APIKey(key=key) for key in keys]
            logger.info("Loaded %d API keys from GEMINI_API_KEYS", len(keys))
            return

        logger.warning("No API keys found in environment variables")

    def get_available_key(self) -> Optional[str]:
        """Get an available API key, considering rate limits.

        Returns the first key in list order that is usable, or None when every
        key is still rate-limited (or there are no keys). A key whose
        rate-limit window has expired is re-enabled as a side effect. The
        returned key's ``last_used`` is set to the current naive-UTC time.
        """
        now = datetime.utcnow()
        for key_obj in self.keys:
            if not key_obj.is_available:
                # Still inside its rate-limit window (or no reset recorded).
                if key_obj.rate_limit_reset is None or now < key_obj.rate_limit_reset:
                    continue
                key_obj.is_available = True
                key_obj.rate_limit_reset = None
            key_obj.last_used = now
            return key_obj.key
        return None

    def mark_key_unavailable(self, key: str, retry_after_seconds: int = 60):
        """Mark a key as unavailable due to rate limiting.

        Every entry holding ``key`` becomes unavailable until
        ``retry_after_seconds`` seconds from now. An unknown key is only
        logged.
        """
        reset_at = datetime.utcnow() + timedelta(seconds=retry_after_seconds)
        matched = False
        for key_obj in self.keys:
            if key_obj.key == key:
                key_obj.is_available = False
                key_obj.rate_limit_reset = reset_at
                matched = True
        if matched:
            logger.warning(
                "Rate limit hit for API key. Will retry after %s seconds",
                retry_after_seconds,
            )
        else:
            logger.warning("Tried to mark unknown API key as unavailable")
# Shared module-level instance. APIKeyManager is a singleton, so this is the
# same object every later APIKeyManager() call returns. Creating it here loads
# the API keys from the environment at import time.
api_key_manager: APIKeyManager = APIKeyManager()