# researchpilot-api / src/rag/llm_client.py
# Last commit by Subhadip007 (f7e2e5e):
#   fix: switch HF endpoint to universal router URL (auto-selects provider)
import os
import json
import requests
from groq import Groq
from src.utils.logger import get_logger
from config.settings import (
GROQ_API_KEY,
HF_API_KEY,
LLM_TEMPERATURE,
LLM_MAX_TOKENS,
)
# Module-level logger from the project's logging helper, created with this
# module's __name__; used by MultiModelClient to report each model attempt.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Provider registry — which backend serves each model ID
# ---------------------------------------------------------------------------
# HF_MODELS: served through the HuggingFace Router chat-completions endpoint.
# GROQ_MODELS: served through the Groq SDK.
# The ordered fallback chain itself lives on MultiModelClient.MODEL_CHAIN;
# keep these sets in sync with it when adding or removing a model.
HF_MODELS = {
    "zai-org/GLM-5.1",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
}
GROQ_MODELS = {
    "llama-3.3-70b-versatile",
}
class MultiModelClient:
    """
    Multi-model LLM client with strict linear fallback.

    Fallback order (never changes regardless of query content):
      1. zai-org/GLM-5.1                 (HF — primary)
      2. Qwen/Qwen3.5-9B                 (HF — first fallback)
      3. llama-3.3-70b-versatile         (Groq — second fallback)
      4. Qwen/Qwen2.5-Coder-7B-Instruct  (HF — final fallback)

    A model ID found in ``HF_MODELS`` is sent to the HuggingFace Router;
    every other ID is sent through the Groq SDK.
    """

    # Strict, ordered fallback chain — do NOT re-order at runtime
    MODEL_CHAIN = [
        "zai-org/GLM-5.1",
        "Qwen/Qwen3.5-9B",
        "llama-3.3-70b-versatile",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
    ]

    # OpenAI-compatible chat-completions endpoint of the HuggingFace Router.
    HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"

    # (connect, read) timeouts in seconds for HF requests. Without a timeout a
    # stalled connection blocks forever and the fallback chain never advances.
    # In requests, the read timeout bounds the gap between received bytes,
    # not the total duration of the response.
    HF_TIMEOUT = (10, 300)

    def __init__(self):
        """Create the Groq client when a key is configured; store the HF key."""
        # A missing key is not fatal here: the matching _call_* method raises
        # ValueError, which generate() treats as "try the next model".
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.hf_api_key = HF_API_KEY

    # ------------------------------------------------------------------
    # Transport helpers
    # ------------------------------------------------------------------
    def _call_hf(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Call one model through the HuggingFace Router chat-completions API.

        Args:
            model_id: Router model ID, e.g. "Qwen/Qwen3.5-9B".
            messages: List of {"role": ..., "content": ...} message dicts.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: Completion token limit forwarded in the payload.
            stream: If True, request a server-sent-event stream and return a
                generator of text tokens instead of a string.

        Returns:
            The completion text, or (stream=True) a generator of text tokens.

        Raises:
            ValueError: HF_API_KEY is not configured.
            Exception: HTTP 429 ("Rate limit") or any other non-2xx response.
            requests.RequestException: connection failure or timeout.
        """
        if not self.hf_api_key:
            raise ValueError("HF_API_KEY not configured")

        headers = {
            "Authorization": f"Bearer {self.hf_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }
        response = requests.post(
            self.HF_ROUTER_URL,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=self.HF_TIMEOUT,
        )

        if not response.ok:
            # Release the connection before raising; with stream=True an
            # unread body would otherwise keep the pooled connection busy.
            try:
                if response.status_code == 429:
                    raise Exception("Rate limit (HTTP 429)")
                raise Exception(f"HF Error: {response.text}")
            finally:
                response.close()

        if not stream:
            return response.json()["choices"][0]["message"]["content"]

        def generator():
            # Once iteration has begun, the finally clause closes the response
            # when the stream ends, hits [DONE], or the consumer abandons it.
            try:
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue  # blank lines separate SSE events
                    line = raw_line.decode("utf-8")
                    if not line.startswith("data:"):
                        continue  # ignore SSE comments and non-data fields
                    # The space after "data:" is optional in SSE.
                    data_str = line[len("data:"):].strip()
                    if data_str == "[DONE]":
                        break
                    token = self._token_from_sse_data(data_str)
                    if token:
                        yield token
            finally:
                response.close()

        return generator()

    @staticmethod
    def _token_from_sse_data(data_str):
        """
        Return ``choices[0].delta.content`` from one streamed JSON chunk.

        Returns "" when the chunk is not valid JSON, has no choices, or carries
        no string content. Such chunks are skipped on purpose so that a single
        malformed or content-free event does not abort the whole stream.
        """
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return ""
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices") or []
        if not isinstance(choices, list) or not choices:
            return ""
        first = choices[0]
        if not isinstance(first, dict):
            return ""
        # Some chunks send "delta": null; treat that like an empty delta.
        delta = first.get("delta") or {}
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    def _call_groq(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Call one model through the Groq SDK chat-completions API.

        Args:
            model_id: Groq model ID, e.g. "llama-3.3-70b-versatile".
            messages: List of {"role": ..., "content": ...} message dicts.
            temperature: Sampling temperature forwarded to the SDK.
            max_tokens: Completion token limit forwarded to the SDK.
            stream: If True, return a generator of text tokens.

        Returns:
            The completion text, or (stream=True) a generator of text tokens.

        Raises:
            ValueError: GROQ_API_KEY is not configured.
            Any exception raised by the Groq SDK call itself.
        """
        if not self.groq_client:
            raise ValueError("GROQ_API_KEY not configured")

        response = self.groq_client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
        )
        if not stream:
            return response.choices[0].message.content

        def generator():
            for chunk in response:
                if not chunk.choices:
                    continue  # e.g. usage-only chunks carry no choices
                token = chunk.choices[0].delta.content
                if token:
                    yield token

        return generator()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        original_query: str = "",
        history: list = None,
        temperature: float = LLM_TEMPERATURE,
        max_tokens: int = LLM_MAX_TOKENS,
        stream: bool = False,
    ):
        """
        Generate a response, trying models in strict MODEL_CHAIN order.

        Args:
            system_prompt: Content of the leading "system" message.
            user_prompt: Content of the final "user" message.
            original_query: Accepted for caller compatibility; not used here.
            history: Optional prior messages placed between the system and
                user messages. Its items are copied into a new list; the
                caller's list is not mutated.
            temperature: Sampling temperature forwarded to every attempt.
            max_tokens: Completion token limit forwarded to every attempt.
            stream: If True, the result is a generator of text tokens.

        Returns:
            Tuple ``(result, model_used)``: ``result`` is a str, or a generator
            of str when stream=True; ``model_used`` is the MODEL_CHAIN entry
            that produced it.

        Raises:
            Exception: every model in the chain failed. The last model's error
                is attached as ``__cause__``.

        Note:
            With stream=True, fallback only covers errors raised before the
            generator is returned (e.g. a missing key, or connection/HTTP
            errors for HF). Errors raised while the caller iterates the
            generator propagate to the caller and do not trigger a fallback.
        """
        messages = [{"role": "system", "content": system_prompt}]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_prompt})

        last_error = None
        for model in self.MODEL_CHAIN:
            logger.info(f"Attempting model: {model}")
            try:
                if model in HF_MODELS:
                    out = self._call_hf(model, messages, temperature, max_tokens, stream)
                else:
                    out = self._call_groq(model, messages, temperature, max_tokens, stream)
            except Exception as e:
                # Any failure (missing key, HTTP/network error, malformed
                # response) moves on to the next model in the chain.
                last_error = e
                logger.warning(f"Model {model} failed: {e}")
                continue
            logger.info(f"Model {model} selected successfully.")
            return out, model

        raise Exception("All models failed.") from last_error