import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import spaces # Import spaces early to enable ZeroGPU support

# Configuration
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "600"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.2"))
DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "Alovestocode/router-gemma3-merged"

# Global model cache
_MODEL = None
_TOKENIZER = None
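# Records which load strategy succeeded ("8bit", "bf16", or "cpu"); shown in the UI info panel.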
ACTIVE_STRATEGY = None

# Detect ZeroGPU environment
IS_ZEROGPU = os.environ.get("SPACE_RUNTIME_STATELESS", "0") == "1"
if os.environ.get("SPACES_ZERO_GPU") is not None:
IS_ZEROGPU = True


def load_model():
"""Load the model on CPU. GPU movement happens inside @spaces.GPU decorated function."""
global _MODEL, _TOKENIZER, ACTIVE_STRATEGY
if _MODEL is None:
print(f"Loading model {MODEL_ID}...")
_TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, token=HF_TOKEN)
if IS_ZEROGPU:
# ZeroGPU: load on CPU with device_map=None
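            # Note: 8-bit loading via bitsandbytes generally requires CUDA at load time,
            # so this attempt may fail in the CPU-only ZeroGPU main process and fall
            # back to the bf16 branch below.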
try:
kwargs = {
"device_map": None, # Stay on CPU for ZeroGPU
"quantization_config": BitsAndBytesConfig(load_in_8bit=True),
"trust_remote_code": True,
"low_cpu_mem_usage": True,
"token": HF_TOKEN,
}
_MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
ACTIVE_STRATEGY = "8bit"
except Exception:
# Fallback to bf16 on CPU
kwargs = {
"device_map": None,
"torch_dtype": torch.bfloat16,
"trust_remote_code": True,
"low_cpu_mem_usage": True,
"token": HF_TOKEN,
}
_MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
ACTIVE_STRATEGY = "bf16"
else:
# Local environment: use GPU if available
if torch.cuda.is_available():
try:
kwargs = {
"device_map": "auto",
"quantization_config": BitsAndBytesConfig(load_in_8bit=True),
"trust_remote_code": True,
"low_cpu_mem_usage": True,
"token": HF_TOKEN,
}
_MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
ACTIVE_STRATEGY = "8bit"
except Exception:
kwargs = {
"device_map": "auto",
"torch_dtype": torch.bfloat16,
"trust_remote_code": True,
"low_cpu_mem_usage": True,
"token": HF_TOKEN,
}
_MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
ACTIVE_STRATEGY = "bf16"
else:
kwargs = {
"device_map": "cpu",
"torch_dtype": torch.float32,
"trust_remote_code": True,
"low_cpu_mem_usage": True,
"token": HF_TOKEN,
}
_MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
ACTIVE_STRATEGY = "cpu"
_MODEL = _MODEL.eval()
print(f"Loaded {MODEL_ID} with strategy='{ACTIVE_STRATEGY}' (ZeroGPU={IS_ZEROGPU})")
return _MODEL, _TOKENIZER


def get_duration(prompt, max_new_tokens, temperature, top_p):
"""Estimate generation duration for ZeroGPU scheduling."""
# Base time + token generation time
base_duration = 20
token_duration = max_new_tokens * 0.005 # ~200 tokens/second
return base_duration + token_duration
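

# `spaces.GPU` accepts a callable for `duration`: ZeroGPU calls it with the same
# arguments as the wrapped function and reserves GPU time proportional to the
# requested number of new tokens instead of a fixed budget.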
@spaces.GPU(duration=get_duration)
def generate_response(prompt, max_new_tokens=MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE, top_p=DEFAULT_TOP_P):
"""
Generate response using the router model.
In ZeroGPU mode: model is loaded on CPU, moved to GPU here, then back to CPU after.
"""
if not prompt.strip():
return "ERROR: Prompt must not be empty."
global _MODEL
model, tokenizer = load_model()
# In ZeroGPU: move model to GPU inside this @spaces.GPU function
current_device = torch.device("cpu")
if IS_ZEROGPU and torch.cuda.is_available():
current_device = torch.device("cuda")
model = model.to(current_device)
elif torch.cuda.is_available() and not IS_ZEROGPU:
current_device = torch.device("cuda")
inputs = tokenizer(prompt, return_tensors="pt").to(current_device)
eos = tokenizer.eos_token_id
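    # Use EOS for both stopping and padding so generate() always has a pad token id.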
try:
with torch.inference_mode():
output_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
                # Sample only when temperature > 0; otherwise fall back to greedy decoding
                do_sample=temperature > 0,
eos_token_id=eos,
pad_token_id=eos,
)
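        # Slice off the prompt tokens before decoding so only newly generated text is returned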
        prompt_len = inputs["input_ids"].shape[1]
        result = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
finally:
# In ZeroGPU: move model back to CPU to free GPU memory
if IS_ZEROGPU and torch.cuda.is_available():
_MODEL = model.to(torch.device("cpu"))
torch.cuda.empty_cache()
return result


# Gradio UI
with gr.Blocks(
title="Router Model API - ZeroGPU",
theme=gr.themes.Soft(
primary_hue="indigo",
secondary_hue="purple",
neutral_hue="slate",
radius_size="lg",
),
css="""
.main-header {
text-align: center;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 20px;
}
.info-box {
background: #f0f0f0;
padding: 15px;
border-radius: 8px;
margin-bottom: 20px;
}
"""
) as demo:
# Header
gr.Markdown("""
<div class="main-header">
<h1>πŸš€ Router Model API - ZeroGPU</h1>
<p>Intelligent routing agent for coordinating specialized AI agents</p>
</div>
""")
with gr.Row():
# Left Panel - Input
with gr.Column(scale=1):
gr.Markdown("### πŸ“ Input")
prompt_input = gr.Textbox(
label="Router Prompt",
lines=10,
placeholder="Enter your router prompt here...\n\nExample:\nYou are the Router Agent coordinating Math, Code, and General-Search specialists.\nUser query: Solve the integral of x^2 from 0 to 1",
)
with gr.Accordion("βš™οΈ Generation Parameters", open=True):
max_tokens_input = gr.Slider(
minimum=64,
maximum=2048,
value=MAX_NEW_TOKENS,
step=16,
label="Max New Tokens",
info="Maximum number of tokens to generate"
)
temp_input = gr.Slider(
minimum=0.0,
maximum=2.0,
value=DEFAULT_TEMPERATURE,
step=0.05,
label="Temperature",
info="Controls randomness: lower = more deterministic"
)
top_p_input = gr.Slider(
minimum=0.0,
maximum=1.0,
value=DEFAULT_TOP_P,
step=0.05,
label="Top-p (Nucleus Sampling)",
info="Probability mass to consider for sampling"
)
generate_btn = gr.Button("πŸš€ Generate", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
# Right Panel - Output
with gr.Column(scale=1):
gr.Markdown("### πŸ“€ Output")
output = gr.Textbox(
label="Generated Response",
lines=25,
placeholder="Generated response will appear here...",
show_copy_button=True,
)
with gr.Accordion("πŸ“š Model Information", open=False):
gr.Markdown(f"""
            - **Model:** `{MODEL_ID}`
            - **Strategy:** `{ACTIVE_STRATEGY or 'pending'}`
            - **ZeroGPU:** `{IS_ZEROGPU}`
            - **Max Tokens:** `{MAX_NEW_TOKENS}`
            - **Default Temperature:** `{DEFAULT_TEMPERATURE}`
            - **Default Top-p:** `{DEFAULT_TOP_P}`
""")
# Event handlers
generate_btn.click(
fn=generate_response,
inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
outputs=output,
)
clear_btn.click(
fn=lambda: ("", ""),
inputs=None,
outputs=[prompt_input, output],
)


# Launch the app
if __name__ == "__main__":
print("Warm start skipped for ZeroGPU. Model will load on first request.")
demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
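
# Example client call -- a sketch, not part of the app. It assumes this file is deployed
# as a public Space and that the click handler is exposed under its default API name
# "/generate_response"; confirm the exact endpoint in the Space's "Use via API" panel.
#
#   from gradio_client import Client
#
#   client = Client("<owner>/<space-name>")  # hypothetical Space id
#   reply = client.predict(
#       "You are the Router Agent coordinating Math, Code, and General-Search specialists...",
#       600,   # max_new_tokens
#       0.2,   # temperature
#       0.9,   # top_p
#       api_name="/generate_response",
#   )
#   print(reply)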