import os
import json
import requests
from groq import Groq
from src.utils.logger import get_logger
from config.settings import (
GROQ_API_KEY,
HF_API_KEY,
LLM_TEMPERATURE,
LLM_MAX_TOKENS,
)
# Module-level logger obtained from the project's get_logger helper,
# keyed by this module's name.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Model registry — single source of truth for every model ID in the system
# ---------------------------------------------------------------------------
# HF models are called via the HuggingFace Router endpoint.
# Groq models are called via the Groq SDK.
# Model IDs that are routed to the HuggingFace Router endpoint.
HF_MODELS = {
    "zai-org/GLM-5.1",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
}
# Model IDs that are routed to the Groq SDK.
GROQ_MODELS = {
    "llama-3.3-70b-versatile",
}
class MultiModelClient:
    """
    Multi-model LLM client with strict linear fallback.

    Fallback order (never changes regardless of query content):
        1. zai-org/GLM-5.1                 (HF — primary)
        2. Qwen/Qwen3.5-9B                 (HF — first fallback)
        3. llama-3.3-70b-versatile         (Groq — second fallback)
        4. Qwen/Qwen2.5-Coder-7B-Instruct  (HF — final fallback)

    Models listed in ``HF_MODELS`` are sent to the HuggingFace Router over
    HTTP; every other model in the chain is sent through the Groq SDK.
    """

    # Strict, ordered fallback chain — do NOT re-order at runtime
    MODEL_CHAIN = [
        "zai-org/GLM-5.1",
        "Qwen/Qwen3.5-9B",
        "llama-3.3-70b-versatile",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
    ]

    # HuggingFace Router chat-completions endpoint.
    HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"

    # (connect, read) timeout in seconds for HF requests. Without a timeout a
    # stalled endpoint would block forever and the fallback chain would never
    # advance to the next model.
    HF_TIMEOUT = (10, 120)

    def __init__(self):
        """Build the Groq client (when a key is configured) and store the HF key."""
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.hf_api_key = HF_API_KEY

    # ------------------------------------------------------------------
    # Transport helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_hf_token(data_str):
        """Return the content token carried by one SSE ``data:`` payload, or None.

        Malformed or unexpectedly shaped payloads are skipped (best effort)
        rather than aborting the whole stream.
        """
        try:
            data = json.loads(data_str)
            choices = data.get("choices") or []
            if not choices:
                return None
            # "delta" and "content" may be missing or explicitly null.
            return (choices[0].get("delta") or {}).get("content")
        except (ValueError, AttributeError, TypeError, KeyError, IndexError) as exc:
            logger.debug(f"Skipping unparsable HF stream chunk: {exc}")
            return None

    def _iter_hf_stream(self, response):
        """Yield content tokens from a streaming HF response, then close it.

        The ``yield`` is deliberately kept outside any try/except so that a
        consumer closing the generator early (GeneratorExit) is never
        swallowed; the ``finally`` releases the HTTP connection in every case
        once iteration has started.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str.strip() == "[DONE]":
                    break
                token = self._extract_hf_token(data_str)
                if token:
                    yield token
        finally:
            response.close()

    def _call_hf(self, model_id, messages, temperature, max_tokens, stream=False):
        """Call a model through the HuggingFace Router chat-completions endpoint.

        Returns the completion text, or a generator of text tokens when
        ``stream`` is true.

        Raises:
            ValueError: if no HF API key is configured.
            Exception: on HTTP 429 or any other non-success status.
            requests.RequestException: on transport failures, including timeouts.
        """
        if not self.hf_api_key:
            raise ValueError("HF_API_KEY not configured")

        headers = {
            "Authorization": f"Bearer {self.hf_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }
        response = requests.post(
            self.HF_CHAT_URL,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=self.HF_TIMEOUT,
        )

        if response.status_code == 429:
            # With stream=True the body has not been read, so release the
            # connection explicitly before bailing out.
            response.close()
            raise Exception("Rate limit (HTTP 429)")
        if not response.ok:
            detail = response.text
            response.close()
            raise Exception(f"HF Error: {detail}")

        if stream:
            return self._iter_hf_stream(response)
        return response.json()["choices"][0]["message"]["content"]

    @staticmethod
    def _iter_groq_stream(response):
        """Yield non-empty content tokens from a Groq streaming response."""
        for chunk in response:
            choices = chunk.choices
            if not choices:
                continue
            token = choices[0].delta.content
            if token:
                yield token

    def _call_groq(self, model_id, messages, temperature, max_tokens, stream=False):
        """Call a model through the Groq SDK.

        Returns the completion text, or a generator of text tokens when
        ``stream`` is true.

        Raises:
            ValueError: if no Groq API key is configured.
        """
        if not self.groq_client:
            raise ValueError("GROQ_API_KEY not configured")

        response = self.groq_client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
        )
        if stream:
            return self._iter_groq_stream(response)
        return response.choices[0].message.content

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        original_query: str = "",
        history: list = None,
        temperature: float = LLM_TEMPERATURE,
        max_tokens: int = LLM_MAX_TOKENS,
        stream: bool = False,
    ):
        """
        Generate a response, trying models in strict fallback order.

        Args:
            system_prompt: Content of the leading ``system`` message.
            user_prompt: Content of the trailing ``user`` message.
            original_query: Accepted for interface compatibility; not used
                by this method.
            history: Optional list of prior chat messages inserted between
                the system and user messages.
            temperature: Sampling temperature forwarded to the backend.
            max_tokens: Completion token limit forwarded to the backend.
            stream: When true, the result is a generator of text tokens.

        Returns:
            A tuple of ``(result, model_used)``. ``result`` is a generator
            if ``stream`` is true, otherwise a string.

        Raises:
            Exception: if every model in ``MODEL_CHAIN`` fails; the last
                underlying error is attached as ``__cause__``.

        Note:
            Fallback only covers failures raised while the call is being
            set up. Token generators are consumed lazily by the caller, so
            an error raised while iterating a stream propagates to the
            consumer and does not trigger the next model.
        """
        messages = [{"role": "system", "content": system_prompt}]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_prompt})

        last_error = None
        for model in self.MODEL_CHAIN:
            call = self._call_hf if model in HF_MODELS else self._call_groq
            logger.info(f"Attempting model: {model}")
            try:
                out = call(model, messages, temperature, max_tokens, stream)
            except Exception as e:
                logger.warning(f"Model {model} failed: {e}")
                last_error = e
                continue
            logger.info(f"Model {model} selected successfully.")
            return out, model

        raise Exception("All models failed.") from last_error