from __future__ import annotations

import os
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover
    class _SpacesShim:
        @staticmethod
        def GPU(*_args, **_kwargs):
            def identity(fn):
                return fn

            return identity

    spaces = _SpacesShim()

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

import gradio as gr

# Configuration
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "600"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.2"))
DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "Alovestocode/router-gemma3-merged"

# Model state
_MODEL = None
ACTIVE_STRATEGY: Optional[str] = None

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, token=HF_TOKEN)
print(f"Loaded tokenizer from {MODEL_ID}")


# Pydantic models
class GeneratePayload(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None


class GenerateResponse(BaseModel):
    text: str


# Detect ZeroGPU environment
IS_ZEROGPU = os.environ.get("SPACE_RUNTIME_STATELESS", "0") == "1"
if os.environ.get("SPACES_ZERO_GPU") is not None:
    IS_ZEROGPU = True


# Model loading - ZeroGPU pattern: load on CPU, move to GPU inside @spaces.GPU functions
def get_model() -> AutoModelForCausalLM:
    """Load the model on CPU. GPU movement happens inside @spaces.GPU decorated functions."""
    global _MODEL, ACTIVE_STRATEGY
    if _MODEL is None:
        # In ZeroGPU: always load on CPU first; the GPU is used only inside @spaces.GPU functions.
        # For local runs: use CUDA if available.
        if IS_ZEROGPU:
            # ZeroGPU: load on CPU with device_map=None
            try:
                kwargs = {
                    "device_map": None,  # Stay on CPU for ZeroGPU
                    "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "8bit"
            except Exception:
                # Fallback to bf16 on CPU
                kwargs = {
                    "device_map": None,
                    "torch_dtype": torch.bfloat16,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "bf16"
        else:
            # Local environment: use GPU if available
            if torch.cuda.is_available():
                try:
                    kwargs = {
                        "device_map": "auto",
                        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "8bit"
                except Exception:
                    kwargs = {
                        "device_map": "auto",
                        "torch_dtype": torch.bfloat16,
                        "trust_remote_code": True,
                        "low_cpu_mem_usage": True,
                        "token": HF_TOKEN,
                    }
                    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                    ACTIVE_STRATEGY = "bf16"
            else:
                kwargs = {
                    "device_map": "cpu",
                    "torch_dtype": torch.float32,
                    "trust_remote_code": True,
                    "low_cpu_mem_usage": True,
                    "token": HF_TOKEN,
                }
                model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
                ACTIVE_STRATEGY = "cpu"
        _MODEL = model.eval()
        print(f"Loaded {MODEL_ID} with strategy='{ACTIVE_STRATEGY}' (ZeroGPU={IS_ZEROGPU})")
    return _MODEL


# ZeroGPU decorated function - moves model to GPU inside this function
@spaces.GPU(duration=300)
def _generate_with_gpu(
    prompt: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> str:
    """
    GPU generation wrapper for ZeroGPU.

    In ZeroGPU mode the model is loaded on CPU, moved to the GPU here, and moved
    back to CPU afterwards to free the shared GPU.
    """
    if not prompt.strip():
        raise ValueError("Prompt must not be empty.")
    global _MODEL
    model = get_model()
    # In ZeroGPU: move the model to the GPU inside this @spaces.GPU function.
    # For local runs the model may already be on the GPU (handled by device_map).
    current_device = torch.device("cpu")
    if IS_ZEROGPU and torch.cuda.is_available():
        current_device = torch.device("cuda")
        model = model.to(current_device)
    elif torch.cuda.is_available() and not IS_ZEROGPU:
        current_device = torch.device("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to(current_device)
    eos = tokenizer.eos_token_id
    try:
        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                eos_token_id=eos,
                pad_token_id=eos,
            )
        # Decode only the newly generated tokens so the prompt is not echoed back.
        prompt_len = inputs["input_ids"].shape[1]
        result = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
    finally:
        # In ZeroGPU: move the model back to CPU to free GPU memory.
        if IS_ZEROGPU and torch.cuda.is_available():
            _MODEL = model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
    return result


# Gradio UI
def gradio_generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str]:
    """Gradio handler for generation."""
    if not prompt.strip():
        return "ERROR: Prompt must not be empty.", "❌ Prompt required."
    try:
        text = _generate_with_gpu(prompt, max_new_tokens, temperature, top_p)
        return text, "✅ Generation successful."
    except Exception as exc:
        return f"ERROR: {exc}", f"❌ Generation failed: {exc}"


# Create Gradio interface
with gr.Blocks(title="Router Model API - ZeroGPU", theme=gr.themes.Soft()) as gradio_app:
    gr.Markdown("# 🚀 Router Model API - ZeroGPU")
    gr.Markdown("Intelligent routing agent for coordinating specialized AI agents.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Router Prompt",
                lines=8,
                placeholder="Enter your router prompt here...",
            )
            max_tokens_input = gr.Slider(
                minimum=64,
                maximum=2048,
                value=MAX_NEW_TOKENS,
                step=16,
                label="Max New Tokens",
            )
            temp_input = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=DEFAULT_TEMPERATURE,
                step=0.05,
                label="Temperature",
            )
            top_p_input = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_TOP_P,
                step=0.05,
                label="Top-p",
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Response",
                lines=20,
                show_copy_button=True,
            )
            status = gr.Markdown("Status: Ready")
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
        outputs=[output, status],
    )


# API routes - registered on a plain FastAPI app; the Gradio UI is mounted onto it
# below. (A gr.Blocks instance does not expose a FastAPI `.app` before launch(),
# so routes cannot be attached to `gradio_app` directly at import time.)
api = FastAPI(title="Router Model API")


@api.get("/health")
def api_health():
    """Health check endpoint."""
    return {
        "status": "ok",
        "model": MODEL_ID,
        "strategy": ACTIVE_STRATEGY or "pending",
    }


@api.post("/v1/generate", response_model=GenerateResponse)
async def api_generate(payload: GeneratePayload):
    """Generate endpoint."""
    try:
        text = _generate_with_gpu(
            prompt=payload.prompt,
            max_new_tokens=payload.max_new_tokens or MAX_NEW_TOKENS,
            temperature=payload.temperature or DEFAULT_TEMPERATURE,
            top_p=payload.top_p or DEFAULT_TOP_P,
        )
        return GenerateResponse(text=text)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# Setup
print("Warm start skipped for ZeroGPU. Model will load on first request.")
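# Outside ZeroGPU (a dedicated-GPU Space or a local run) the model could instead be
# warmed eagerly at startup. A minimal sketch, assuming get_model() is acceptable to
# call at import time in those environments (not enabled here):
#
#   if not IS_ZEROGPU:
#       get_model()  # pay the load cost once at startup rather than on the first request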
gradio_app.queue(max_size=8)

# Mount the Gradio UI onto the FastAPI app so the UI ("/") and the REST routes
# ("/health", "/v1/generate") are served by the same server.
app = gr.mount_gradio_app(api, gradio_app, path="/")

if __name__ == "__main__":  # pragma: no cover
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
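# ---------------------------------------------------------------------------
# Example client call (illustrative only, not part of the app). It assumes the
# server above is reachable at http://localhost:7860, that the `requests`
# package is installed, and uses a made-up prompt:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/generate",
#       json={"prompt": "Route this task to the best agent: ...", "max_new_tokens": 256},
#       timeout=600,
#   )
#   resp.raise_for_status()
#   print(resp.json()["text"])
#
# A quick health probe: requests.get("http://localhost:7860/health").json()
# ---------------------------------------------------------------------------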