import os
import time
import gc
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import spaces  # Import spaces early to enable ZeroGPU support

# Configuration
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "600"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.2"))
DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "Alovestocode/router-gemma3-merged"

# Global model cache
_MODEL = None
_TOKENIZER = None
ACTIVE_STRATEGY = None

# Detect ZeroGPU environment
IS_ZEROGPU = os.environ.get("SPACE_RUNTIME_STATELESS", "0") == "1"
if os.environ.get("SPACES_ZERO_GPU") is not None:
    IS_ZEROGPU = True


def load_model():
    """Load the model on CPU. GPU movement happens inside the @spaces.GPU-decorated function."""
    global _MODEL, _TOKENIZER, ACTIVE_STRATEGY
    if _MODEL is None:
        print(f"Loading model {MODEL_ID}...")
        _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, token=HF_TOKEN)
        if IS_ZEROGPU:
            # ZeroGPU: load on CPU with device_map=None, preferring 8-bit quantization
            try:
                kwargs = {
                    "device_map": None,  # Stay on CPU for ZeroGPU
                    "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "8bit"
            except Exception:
                # Fallback to bf16 on CPU
                kwargs = {
                    "device_map": None,
                    "torch_dtype": torch.bfloat16,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "bf16"
        else:
            # Local environment: use GPU if available
            if torch.cuda.is_available():
                try:
                    kwargs = {
                        "device_map": "auto",
                        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "8bit"
                except Exception:
                    # Fallback to bf16 on GPU if 8-bit loading is unavailable
                    kwargs = {
                        "device_map": "auto",
                        "torch_dtype": torch.bfloat16,
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "bf16"
            else:
                # CPU-only fallback
                kwargs = {
                    "device_map": "cpu",
                    "torch_dtype": torch.float32,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "cpu"
        _MODEL = _MODEL.eval()
        print(f"Loaded {MODEL_ID} with strategy='{ACTIVE_STRATEGY}' (ZeroGPU={IS_ZEROGPU})")
    return _MODEL, _TOKENIZER


def get_duration(prompt, max_new_tokens, temperature, top_p):
    """Estimate generation duration in seconds for ZeroGPU scheduling."""
    # Base time + per-token generation time
    base_duration = 20
    token_duration = max_new_tokens * 0.005  # ~200 tokens/second
    return base_duration + token_duration
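

# With the defaults above, each call asks the ZeroGPU scheduler for roughly
# 20 + 600 * 0.005 = 23 seconds of GPU time. A quick sanity check of the
# estimator (illustrative only; not executed by the app):
#
#   assert get_duration("example prompt", MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P) == 23.0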


@spaces.GPU(duration=get_duration)
def generate_response(prompt, max_new_tokens=MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE, top_p=DEFAULT_TOP_P):
    """
    Generate a response using the router model.

    In ZeroGPU mode the model is loaded on CPU, moved to the GPU here,
    then moved back to CPU afterwards to free GPU memory.
    """
    global _MODEL

    if not prompt.strip():
        return "ERROR: Prompt must not be empty."

    model, tokenizer = load_model()

    # In ZeroGPU: move the model to the GPU inside this @spaces.GPU function
    current_device = torch.device("cpu")
    if IS_ZEROGPU and torch.cuda.is_available():
        current_device = torch.device("cuda")
        model = model.to(current_device)
    elif torch.cuda.is_available() and not IS_ZEROGPU:
        current_device = torch.device("cuda")

    inputs = tokenizer(prompt, return_tensors="pt").to(current_device)
    eos = tokenizer.eos_token_id
    try:
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,  # fall back to greedy decoding when temperature is 0
                eos_token_id=eos,
                pad_token_id=eos,
            )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        result = text[len(prompt):].strip() or text.strip()
    finally:
        # In ZeroGPU: move the model back to CPU to free GPU memory
        if IS_ZEROGPU and torch.cuda.is_available():
            _MODEL = model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
    return result


# Gradio UI
with gr.Blocks(
    title="Router Model API - ZeroGPU",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
        radius_size="lg",
    ),
    css="""
    .main-header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .info-box {
        background: #f0f0f0;
        padding: 15px;
        border-radius: 8px;
        margin-bottom: 20px;
    }
    """,
) as demo:
    # Header
    gr.Markdown("""

    <div class="main-header">
        <h1>🚀 Router Model API - ZeroGPU</h1>
        <p>Intelligent routing agent for coordinating specialized AI agents</p>
    </div>
    """)

    with gr.Row():
        # Left Panel - Input
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")
            prompt_input = gr.Textbox(
                label="Router Prompt",
                lines=10,
                placeholder="Enter your router prompt here...\n\nExample:\nYou are the Router Agent coordinating Math, Code, and General-Search specialists.\nUser query: Solve the integral of x^2 from 0 to 1",
            )

            with gr.Accordion("⚙️ Generation Parameters", open=True):
                max_tokens_input = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=MAX_NEW_TOKENS,
                    step=16,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate",
                )
                temp_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.05,
                    label="Temperature",
                    info="Controls randomness: lower = more deterministic",
                )
                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Probability mass to consider for sampling",
                )

            generate_btn = gr.Button("🚀 Generate", variant="primary")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Right Panel - Output
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Output")
            output = gr.Textbox(
                label="Generated Response",
                lines=25,
                placeholder="Generated response will appear here...",
                show_copy_button=True,
            )

            with gr.Accordion("📚 Model Information", open=False):
                gr.Markdown(f"""
                **Model:** `{MODEL_ID}`
                **Strategy:** `{ACTIVE_STRATEGY or 'pending'}`
                **ZeroGPU:** `{IS_ZEROGPU}`
                **Max Tokens:** `{MAX_NEW_TOKENS}`
                **Default Temperature:** `{DEFAULT_TEMPERATURE}`
                **Default Top-p:** `{DEFAULT_TOP_P}`
                """)

    # Event handlers
    generate_btn.click(
        fn=generate_response,
        inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
        outputs=output,
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[prompt_input, output],
    )

# Launch the app
if __name__ == "__main__":
    print("Warm start skipped for ZeroGPU. Model will load on first request.")
    demo.queue(max_size=8).launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )