File size: 5,878 Bytes
2671aea
9b7c6ff
 
2671aea
 
 
 
9b7c6ff
2671aea
 
 
 
 
 
76e224e
 
 
 
 
 
 
 
 
9b7c6ff
2671aea
76e224e
 
 
 
 
 
 
2671aea
 
76e224e
 
 
 
 
 
 
 
2671aea
9b7c6ff
 
 
 
 
 
 
76e224e
 
 
9b7c6ff
 
 
 
f7e2e5e
9b7c6ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3394ee5
 
 
 
9b7c6ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3394ee5
 
 
 
9b7c6ff
 
 
 
 
2671aea
76e224e
 
 
2671aea
 
 
9b7c6ff
 
5951bbe
9b7c6ff
 
 
 
2671aea
76e224e
9b7c6ff
 
 
2671aea
5951bbe
 
 
 
9b7c6ff
76e224e
9b7c6ff
76e224e
9b7c6ff
 
 
 
 
76e224e
9b7c6ff
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import json
import requests
from groq import Groq
from src.utils.logger import get_logger
from config.settings import (
    GROQ_API_KEY,
    HF_API_KEY,
    LLM_TEMPERATURE,
    LLM_MAX_TOKENS,
)

# Module-level logger obtained from the project's logging helper, named after
# this module; used below to trace which model in the fallback chain is tried.
logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Transport registry: which backend serves each model ID
# ---------------------------------------------------------------------------
# Models in HF_MODELS are sent to the HuggingFace Router chat endpoint.
# GROQ_MODELS lists the models served through the Groq SDK.
HF_MODELS = {
    "zai-org/GLM-5.1",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
}
GROQ_MODELS = {
    "llama-3.3-70b-versatile",
}


class MultiModelClient:
    """
    Multi-model LLM client with strict linear fallback.

    Fallback order (never changes regardless of query content):
        1. zai-org/GLM-5.1          (HF — primary)
        2. Qwen/Qwen3.5-9B         (HF — first fallback)
        3. llama-3.3-70b-versatile  (Groq — second fallback)
        4. Qwen/Qwen2.5-Coder-7B-Instruct (HF — final fallback)

    Models found in the module-level ``HF_MODELS`` set are sent to the
    HuggingFace Router chat-completions endpoint. Every other model in the
    chain is sent to the Groq SDK.

    Note: fallback only covers failures raised while *starting* a request.
    With ``stream=True`` the caller receives a generator. Errors raised while
    that generator is being consumed (e.g. a read timeout mid-stream)
    propagate to the caller and do not advance to the next model.
    """

    # Strict, ordered fallback chain — do NOT re-order at runtime
    MODEL_CHAIN = [
        "zai-org/GLM-5.1",
        "Qwen/Qwen3.5-9B",
        "llama-3.3-70b-versatile",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
    ]

    # HuggingFace Router chat-completions endpoint.
    HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"

    # (connect, read) timeout in seconds for HF requests. Without a timeout,
    # ``requests`` waits indefinitely. A hung provider would then stall the
    # whole fallback chain instead of letting it move on to the next model.
    # The read timeout bounds the gap between received bytes, so for
    # non-streamed calls it must cover the time to produce the full answer.
    HF_TIMEOUT = (10, 180)

    def __init__(self):
        # Groq is optional: without a key, ``_call_groq`` raises ValueError
        # and ``generate`` falls through to the next model in the chain.
        self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
        self.hf_api_key = HF_API_KEY

    # ------------------------------------------------------------------
    # Transport helpers
    # ------------------------------------------------------------------
    def _call_hf(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Send a chat-completion request to the HuggingFace Router.

        Args:
            model_id: HF model identifier, e.g. "Qwen/Qwen3.5-9B".
            messages: List of {"role": ..., "content": ...} dicts.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: ``max_tokens`` forwarded in the payload.
            stream: If True, request a streamed response and return a
                generator of text chunks instead of a string.

        Returns:
            The completion text, or a generator yielding text chunks when
            ``stream`` is True.

        Raises:
            ValueError: If HF_API_KEY is not configured.
            Exception: On HTTP 429 or any other non-2xx response.
            requests.RequestException: On network errors or timeouts.
        """
        if not self.hf_api_key:
            raise ValueError("HF_API_KEY not configured")

        headers = {
            "Authorization": f"Bearer {self.hf_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
        }

        response = requests.post(
            self.HF_CHAT_URL,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=self.HF_TIMEOUT,
        )
        if not response.ok:
            try:
                if response.status_code == 429:
                    raise Exception("Rate limit (HTTP 429)")
                raise Exception(f"HF Error: {response.text}")
            finally:
                # Release the connection; with stream=True the body would
                # otherwise stay open after we raise.
                response.close()

        if stream:
            return self._iter_hf_stream(response)
        return response.json()["choices"][0]["message"]["content"]

    @staticmethod
    def _iter_hf_stream(response):
        """
        Yield text deltas from a server-sent-events chat-completion stream.

        Only ``data:`` lines are parsed. The ``[DONE]`` sentinel ends the
        stream. The following chunks are skipped:
            - chunks that are not valid JSON objects;
            - chunks with no ``choices``;
            - chunks with no ``delta.content`` (e.g. role-only chunks).

        The HTTP response is closed in all cases:
            - when the stream ends;
            - when the consumer stops iterating early (generator close);
            - when an error propagates.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                line = raw_line.decode("utf-8")
                if not line.startswith("data:"):
                    continue  # ignore SSE comments and other field types
                # The SSE spec allows "data:" with or without a following space.
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.debug(f"Skipping malformed HF stream chunk: {data_str!r}")
                    continue
                if not isinstance(data, dict):
                    continue
                choices = data.get("choices") or []
                if not choices or not isinstance(choices[0], dict):
                    continue
                # "delta" may be absent or explicitly null in some chunks.
                delta = choices[0].get("delta") or {}
                token = delta.get("content") if isinstance(delta, dict) else None
                if token:
                    yield token
        finally:
            response.close()

    def _call_groq(self, model_id, messages, temperature, max_tokens, stream=False):
        """
        Send a chat-completion request through the Groq SDK client.

        Args:
            model_id: Groq model identifier, e.g. "llama-3.3-70b-versatile".
            messages: List of {"role": ..., "content": ...} dicts.
            temperature: Sampling temperature forwarded to the SDK.
            max_tokens: ``max_tokens`` forwarded to the SDK.
            stream: If True, return a generator of text chunks.

        Returns:
            The completion text, or a generator yielding non-empty text
            chunks when ``stream`` is True.

        Raises:
            ValueError: If GROQ_API_KEY is not configured.
            Any exception raised by the Groq SDK call.
        """
        if not self.groq_client:
            raise ValueError("GROQ_API_KEY not configured")

        response = self.groq_client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
        )
        if not stream:
            return response.choices[0].message.content

        def generator():
            for chunk in response:
                choices = chunk.choices
                if not choices:
                    continue  # e.g. usage-only chunks carry no choices
                token = choices[0].delta.content
                if token:
                    yield token

        return generator()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        original_query: str = "",
        history: list = None,
        temperature: float = LLM_TEMPERATURE,
        max_tokens: int = LLM_MAX_TOKENS,
        stream: bool = False
    ):
        """
        Generate a response, trying models in strict ``MODEL_CHAIN`` order.

        Args:
            system_prompt: Content of the leading "system" message.
            user_prompt: Content of the final "user" message.
            original_query: Currently unused; accepted for caller
                compatibility.
            history: Optional list of prior message dicts, inserted between
                the system and user messages.
            temperature: Sampling temperature forwarded to each model.
            max_tokens: Token limit forwarded to each model.
            stream: If True, the result is a generator of text chunks.

        Returns:
            A tuple ``(result, model_used)``. ``result`` is a string, or a
            generator when ``stream`` is True.

        Raises:
            Exception: "All models failed." if every model raised while the
                request was being started. The last underlying error is
                attached as ``__cause__``.
        """
        messages = [{"role": "system", "content": system_prompt}]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_prompt})

        last_error = None
        for model in self.MODEL_CHAIN:
            call = self._call_hf if model in HF_MODELS else self._call_groq
            logger.info(f"Attempting model: {model}")
            try:
                out = call(model, messages, temperature, max_tokens, stream)
            except Exception as e:
                # Any transport/config failure moves on to the next model.
                logger.warning(f"Model {model} failed: {e}")
                last_error = e
                continue

            logger.info(f"Model {model} selected successfully.")
            return out, model

        raise Exception("All models failed.") from last_error