import os
import time
import gc
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import spaces  # Import spaces early to enable ZeroGPU support

# Configuration
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "600"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.2"))
DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
HF_TOKEN = os.environ.get("HF_TOKEN")

MODEL_ID = "Alovestocode/router-gemma3-merged"

# Global model cache
_MODEL = None
_TOKENIZER = None
ACTIVE_STRATEGY = None

# Detect ZeroGPU environment
IS_ZEROGPU = os.environ.get("SPACE_RUNTIME_STATELESS", "0") == "1"
if os.environ.get("SPACES_ZERO_GPU") is not None:
    IS_ZEROGPU = True

def load_model():
    """Load the model on CPU. GPU movement happens inside @spaces.GPU decorated function."""
    global _MODEL, _TOKENIZER, ACTIVE_STRATEGY
    
    if _MODEL is None:
        print(f"Loading model {MODEL_ID}...")
        _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, token=HF_TOKEN)
        
        if IS_ZEROGPU:
            # ZeroGPU: load on CPU with device_map=None
            try:
                kwargs = {
                    "device_map": None,  # Stay on CPU for ZeroGPU
                    "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "8bit"
            except Exception:
                # Fallback to bf16 on CPU
                kwargs = {
                    "device_map": None,
                    "torch_dtype": torch.bfloat16,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "bf16"
        else:
            # Local environment: use GPU if available
            if torch.cuda.is_available():
                try:
                    kwargs = {
                        "device_map": "auto",
                        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "8bit"
                except Exception:
                    kwargs = {
                        "device_map": "auto",
                        "torch_dtype": torch.bfloat16,
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "bf16"
            else:
                kwargs = {
                    "device_map": "cpu",
                    "torch_dtype": torch.float32,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "cpu"
        
        _MODEL = _MODEL.eval()
        print(f"Loaded {MODEL_ID} with strategy='{ACTIVE_STRATEGY}' (ZeroGPU={IS_ZEROGPU})")
    
    return _MODEL, _TOKENIZER

def get_duration(prompt, max_new_tokens, temperature, top_p):
    """Estimate generation duration for ZeroGPU scheduling."""
    # Base time + token generation time
    base_duration = 20
    token_duration = max_new_tokens * 0.005  # ~200 tokens/second
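    # Illustrative arithmetic: with the default MAX_NEW_TOKENS of 600 this requests
    # roughly 20 + 600 * 0.005 = 23 seconds of GPU time; larger budgets scale linearly.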
    return base_duration + token_duration

@spaces.GPU(duration=get_duration)
def generate_response(prompt, max_new_tokens=MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE, top_p=DEFAULT_TOP_P):
    """
    Generate response using the router model.
    In ZeroGPU mode: model is loaded on CPU, moved to GPU here, then back to CPU after.
    """
    if not prompt.strip():
        return "ERROR: Prompt must not be empty."
    
    global _MODEL
    model, tokenizer = load_model()
    
    # In ZeroGPU: move model to GPU inside this @spaces.GPU function
    current_device = torch.device("cpu")
    if IS_ZEROGPU and torch.cuda.is_available():
        current_device = torch.device("cuda")
        model = model.to(current_device)
    elif torch.cuda.is_available() and not IS_ZEROGPU:
        current_device = torch.device("cuda")
    
    inputs = tokenizer(prompt, return_tensors="pt").to(current_device)
    eos = tokenizer.eos_token_id
    
    try:
        # Sampling requires a strictly positive temperature, so fall back to
        # greedy decoding when the temperature slider is set to 0.
        do_sample = temperature > 0
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                eos_token_id=eos,
                pad_token_id=eos,
            )
        
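        # Strip the echoed prompt from the decoded output; if the decoded text does
        # not start with the prompt verbatim, fall back to returning the full decode.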
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        result = text[len(prompt):].strip() or text.strip()
    finally:
        # In ZeroGPU: move model back to CPU to free GPU memory
        if IS_ZEROGPU and torch.cuda.is_available():
            _MODEL = model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
    
    return result

# Gradio UI
with gr.Blocks(
    title="Router Model API - ZeroGPU",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
        radius_size="lg",
    ),
    css="""
        .main-header {
            text-align: center;
            padding: 20px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 10px;
            margin-bottom: 20px;
        }
        .info-box {
            background: #f0f0f0;
            padding: 15px;
            border-radius: 8px;
            margin-bottom: 20px;
        }
    """
) as demo:
    # Header
    gr.Markdown("""
    <div class="main-header">
        <h1>🚀 Router Model API - ZeroGPU</h1>
        <p>Intelligent routing agent for coordinating specialized AI agents</p>
    </div>
    """)
    
    with gr.Row():
        # Left Panel - Input
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")
            prompt_input = gr.Textbox(
                label="Router Prompt",
                lines=10,
                placeholder="Enter your router prompt here...\n\nExample:\nYou are the Router Agent coordinating Math, Code, and General-Search specialists.\nUser query: Solve the integral of x^2 from 0 to 1",
            )
            
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                max_tokens_input = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=MAX_NEW_TOKENS,
                    step=16,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )
                temp_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.05,
                    label="Temperature",
                    info="Controls randomness: lower = more deterministic"
                )
                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Probability mass to consider for sampling"
                )
            
            generate_btn = gr.Button("🚀 Generate", variant="primary")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")
        
        # Right Panel - Output
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Output")
            output = gr.Textbox(
                label="Generated Response",
                lines=25,
                placeholder="Generated response will appear here...",
                show_copy_button=True,
            )
            
            with gr.Accordion("📚 Model Information", open=False):
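                # Note: ACTIVE_STRATEGY is read once when this UI is built, before any
                # request has triggered load_model(), so this field typically shows 'pending'.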
                gr.Markdown(f"""
                **Model:** `{MODEL_ID}`  
                **Strategy:** `{ACTIVE_STRATEGY or 'pending'}`  
                **ZeroGPU:** `{IS_ZEROGPU}`  
                **Max Tokens:** `{MAX_NEW_TOKENS}`  
                **Default Temperature:** `{DEFAULT_TEMPERATURE}`  
                **Default Top-p:** `{DEFAULT_TOP_P}`
                """)
    
    # Event handlers
    generate_btn.click(
        fn=generate_response,
        inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
        outputs=output,
    )
    
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[prompt_input, output],
    )

# Launch the app
if __name__ == "__main__":
    print("Warm start skipped for ZeroGPU. Model will load on first request.")
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
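
# --- Usage sketch (illustrative, not executed by the app) ---
# A minimal client-side example of calling this Space once it is running, using
# gradio_client. The Space id below is a placeholder and the endpoint name is an
# assumption (Gradio derives it from the handler function name by default);
# confirm the exact signature with Client.view_api() before relying on it.
#
#   from gradio_client import Client
#
#   client = Client("<user>/<space-name>")  # placeholder Space id
#   result = client.predict(
#       "You are the Router Agent coordinating Math, Code, and General-Search specialists.\n"
#       "User query: Solve the integral of x^2 from 0 to 1",
#       600,   # max_new_tokens
#       0.2,   # temperature
#       0.9,   # top_p
#       api_name="/generate_response",  # assumed endpoint name; verify with view_api()
#   )
#   print(result)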