Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import requests | |
| from groq import Groq | |
| from src.utils.logger import get_logger | |
| from config.settings import ( | |
| GROQ_API_KEY, | |
| HF_API_KEY, | |
| LLM_TEMPERATURE, | |
| LLM_MAX_TOKENS, | |
| ) | |
logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Model registry — maps each model ID to the transport that serves it.
# The fallback *order* lives in MultiModelClient.MODEL_CHAIN; every model in
# that chain should appear in exactly one of the sets below.
# ---------------------------------------------------------------------------

# Served by the Groq SDK.
GROQ_MODELS = {"llama-3.3-70b-versatile"}

# Served by the HuggingFace Router (OpenAI-compatible chat-completions API).
HF_MODELS = {
    "zai-org/GLM-5.1",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
}
class MultiModelClient:
    """
    Multi-model LLM client with strict linear fallback.

    Fallback order (never changes regardless of query content):
        1. zai-org/GLM-5.1                  (HF   — primary)
        2. Qwen/Qwen3.5-9B                  (HF   — first fallback)
        3. llama-3.3-70b-versatile          (Groq — second fallback)
        4. Qwen/Qwen2.5-Coder-7B-Instruct   (HF   — final fallback)

    Each model is routed by membership in the module-level ``HF_MODELS`` /
    ``GROQ_MODELS`` sets. If a call raises, the error is logged and the next
    model in ``MODEL_CHAIN`` is tried. Causes include a missing API key, HTTP
    error, rate limit, timeout, malformed or empty response, or an
    unregistered model.

    Streaming caveat: with ``stream=True`` a model counts as successful as
    soon as its request is accepted. Errors raised later, while the caller
    iterates the token generator, propagate to the caller and do NOT trigger
    fallback.
    """

    # Strict, ordered fallback chain — do NOT re-order at runtime
    MODEL_CHAIN = [
        "zai-org/GLM-5.1",
        "Qwen/Qwen3.5-9B",
        "llama-3.3-70b-versatile",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
    ]

    # HuggingFace Router chat-completions endpoint (OpenAI-compatible).
    HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"

    # (connect, read) timeout in seconds for HF Router requests. Without a
    # timeout a stalled connection blocks forever and the fallback chain never
    # advances. The read timeout is generous because a non-streamed
    # completion sends no bytes until generation has finished.
    DEFAULT_HF_TIMEOUT = (10, 180)

    def __init__(self, request_timeout=DEFAULT_HF_TIMEOUT):
        """
        Args:
            request_timeout: timeout passed to ``requests.post`` for HF Router
                calls. Either seconds or a ``(connect, read)`` tuple; ``None``
                disables the timeout (not recommended).
        """
        # The Groq client exists only when a key is configured; _call_groq
        # raises otherwise so that generate() moves on to the next model.
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.hf_api_key = HF_API_KEY
        self.request_timeout = request_timeout

    # ------------------------------------------------------------------
    # Transport helpers
    # ------------------------------------------------------------------
    def _call_hf(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Call ``model_id`` through the HuggingFace Router.

        Returns:
            The completion text, or a generator of content tokens when
            ``stream`` is true.

        Raises:
            ValueError: HF_API_KEY is not configured.
            requests.RequestException: connection failure or timeout.
            Exception: HTTP 429, any other non-2xx status, or a non-streamed
                response whose message content is null.
        """
        if not self.hf_api_key:
            raise ValueError("HF_API_KEY not configured")

        headers = {
            "Authorization": f"Bearer {self.hf_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }
        response = requests.post(
            self.HF_ROUTER_URL,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=self.request_timeout,
        )

        if not response.ok:
            status, body = response.status_code, response.text
            # Release the connection before raising; with stream=True it would
            # otherwise stay checked out of the connection pool.
            response.close()
            if status == 429:
                raise Exception("Rate limit (HTTP 429)")
            raise Exception(f"HF Error (HTTP {status}): {body}")

        if stream:
            return self._iter_hf_stream(response)

        content = response.json()["choices"][0]["message"]["content"]
        if content is None:
            # The caller is promised a string; treat a null content as a
            # failure so the next model in the chain is tried.
            raise Exception(f"HF model {model_id} returned no message content")
        return content

    @staticmethod
    def _iter_hf_stream(response):
        """
        Yield content tokens from an OpenAI-style server-sent-events stream.

        Non-``data:`` lines, undecodable JSON and events without delta content
        are skipped. Once iteration has started, the HTTP response is closed
        when the stream ends, raises, or the consumer closes the generator.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                line = raw_line.decode("utf-8")
                if not line.startswith("data:"):
                    continue  # ignore SSE comments and event:/id: fields
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if isinstance(data, dict) and data.get("error"):
                    logger.warning(f"HF stream reported an error: {data['error']}")
                try:
                    token = data["choices"][0]["delta"]["content"]
                except (KeyError, IndexError, TypeError):
                    continue  # keep-alive / role-only / non-content event
                # The yield sits outside any try/except so that GeneratorExit
                # from an early close() propagates instead of being swallowed.
                if token:
                    yield token
        finally:
            response.close()

    def _call_groq(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Call ``model_id`` through the Groq SDK.

        Returns:
            The completion text, or a generator of content tokens when
            ``stream`` is true.

        Raises:
            ValueError: GROQ_API_KEY is not configured.
            Exception: any error raised by the Groq SDK, or a non-streamed
                response whose message content is null.
        """
        if not self.groq_client:
            raise ValueError("GROQ_API_KEY not configured")

        response = self.groq_client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
        )

        if stream:
            return self._iter_groq_stream(response)

        content = response.choices[0].message.content
        if content is None:
            raise Exception(f"Groq model {model_id} returned no message content")
        return content

    @staticmethod
    def _iter_groq_stream(chunks):
        """
        Yield non-empty delta contents from a Groq streaming response.

        Once iteration has started, the stream is closed (when it exposes
        ``close()``) on normal end, on error, or on early close by the
        consumer.
        """
        try:
            for chunk in chunks:
                if not chunk.choices:
                    continue
                token = chunk.choices[0].delta.content
                if token:
                    yield token
        finally:
            close = getattr(chunks, "close", None)
            if callable(close):
                close()

    def _call_model(self, model_id, messages, temperature, max_tokens, stream):
        """Dispatch ``model_id`` to its transport; raise if it is unregistered."""
        if model_id in HF_MODELS:
            return self._call_hf(model_id, messages, temperature, max_tokens, stream)
        if model_id in GROQ_MODELS:
            return self._call_groq(model_id, messages, temperature, max_tokens, stream)
        # Previously any non-HF model silently went to Groq; fail loudly so a
        # registry mistake shows up in the logs.
        raise ValueError(
            f"Model {model_id!r} is not registered in HF_MODELS or GROQ_MODELS"
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        original_query: str = "",
        history: list = None,
        temperature: float = LLM_TEMPERATURE,
        max_tokens: int = LLM_MAX_TOKENS,
        stream: bool = False
    ):
        """
        Generate a response, trying models in strict fallback order.

        Args:
            system_prompt: content of the leading system message.
            user_prompt: content of the final user message.
            original_query: not used by this method; kept for interface
                compatibility.
            history: optional chat messages inserted between the system and
                user messages, in order.
            temperature: sampling temperature forwarded to every model.
            max_tokens: completion token limit forwarded to every model.
            stream: when true, the result is a generator of text tokens.

        Returns:
            A tuple ``(result, model_used)``. ``result`` is a generator if
            ``stream`` is true, otherwise a string.

        Raises:
            Exception: every model in ``MODEL_CHAIN`` failed. The message
                lists each model's error, and ``__cause__`` is the last one.
        """
        messages = [{"role": "system", "content": system_prompt}]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_prompt})

        errors = []
        last_error = None
        for model in self.MODEL_CHAIN:
            logger.info(f"Attempting model: {model}")
            try:
                out = self._call_model(model, messages, temperature, max_tokens, stream)
            except Exception as e:  # fallback boundary: log and try the next model
                logger.warning(f"Model {model} failed: {e}")
                errors.append(f"{model}: {e}")
                last_error = e
                continue
            logger.info(f"Model {model} selected successfully.")
            return out, model

        details = "; ".join(errors) or "no models configured"
        raise Exception(f"All models failed. {details}") from last_error