"""Gradio UI setup"""
import os
import time
import gradio as gr
import spaces
from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL
import config
from indexing import create_or_update_index
from pipeline import stream_chat
from voice import transcribe_audio, generate_speech
from models import (
initialize_medical_model,
is_model_loaded,
get_model_loading_state,
set_model_loading_state,
initialize_tts_model,
initialize_whisper_model,
TTS_AVAILABLE,
SNAC_AVAILABLE,
WHISPER_AVAILABLE,
)
from logger import logger
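# Maximum GPU task duration in seconds, intended for the @spaces.GPU(max_duration=...) decorators below (currently commented out)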
MAX_DURATION = 120
def create_demo():
"""Create and return Gradio demo interface"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
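# Single-element list so the nested recording handlers can update the start time via closure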
recording_start_time = [None]
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
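# Microphone wiring: show the recording indicator while recording, then transcribe the clip into the message box on stop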
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
with gr.Row():
tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
tts_audio = gr.Audio(label="", visible=True, autoplay=True, show_label=False, container=False)
def generate_speech_from_chat(history):
"""Extract last assistant message and generate speech"""
if not history or len(history) == 0:
logger.warning("[TTS] No history available")
return None
last_msg = history[-1]
if last_msg.get("role") == "assistant":
text = last_msg.get("content", "").replace(" 🔊", "").strip()
if text:
logger.info(f"[TTS] Generating speech for text: {text[:100]}...")
try:
audio_path = generate_speech(text)
if audio_path and os.path.exists(audio_path):
logger.info(f"[TTS] ā
Generated audio successfully: {audio_path}")
return audio_path
else:
logger.warning(f"[TTS] ā Failed to generate audio or file doesn't exist: {audio_path}")
return None
except Exception as e:
logger.error(f"[TTS] Error generating speech: {e}")
import traceback
logger.debug(f"[TTS] Traceback: {traceback.format_exc()}")
return None
else:
logger.warning("[TTS] Empty text extracted from assistant message")
else:
logger.warning(f"[TTS] Last message is not from assistant: {last_msg.get('role')}")
return None
def update_tts_button(history):
if history and len(history) > 0 and history[-1].get("role") == "assistant":
return gr.update(visible=True)
return gr.update(visible=False)
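# Reveal the play button whenever the latest chat message is from the assistant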
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
disable_agentic_reasoning = gr.Checkbox(
value=False,
label="Disable agentic reasoning",
info="Use MedSwin model alone without agentic reasoning, RAG, or web search"
)
show_agentic_thought = gr.Button(
"Show agentic thought",
size="sm"
)
enable_clinical_intake = gr.Checkbox(
value=True,
label="Enable clinical intake (max 5 Q&A)",
info="Ask focused follow-up questions before breaking down the case"
)
agentic_thoughts_box = gr.Textbox(
label="Agentic Thoughts",
placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
lines=8,
max_lines=15,
interactive=False,
visible=False,
elem_classes="agentic-thoughts"
)
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (upload required)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin DT (default), others download on selection"
)
model_status = gr.Textbox(
value="Checking model status...",
label="Model Status",
interactive=False,
visible=True,
lines=3,
max_lines=3,
elem_classes="model-status"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available. Provide answers directly without conversational prefixes like 'Here is...', 'This is...', or 'To answer your question...'. Start with the actual content immediately.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
# MedSwin Model Links
gr.Markdown(
"""
MedSwin Models on Hugging Face
Click any model name to view details on Hugging Face
"""
)
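# Per-session flag tracking whether the agentic thoughts box is currently visible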
show_thoughts_state = gr.State(value=False)
def toggle_thoughts_box(current_state):
"""Toggle visibility of agentic thoughts box"""
new_state = not current_state
return gr.update(visible=new_state), new_state
show_agentic_thought.click(
fn=toggle_thoughts_box,
inputs=[show_thoughts_state],
outputs=[agentic_thoughts_box, show_thoughts_state]
)
# Loader for any user-selected medical model (the @spaces.GPU decorator below is currently disabled)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_model_with_gpu(model_name):
"""Load medical model (GPU-decorated for ZeroGPU compatibility)"""
try:
if not is_model_loaded(model_name):
logger.info(f"Loading medical model: {model_name}...")
set_model_loading_state(model_name, "loading")
try:
initialize_medical_model(model_name)
logger.info(f"ā
Medical model {model_name} loaded successfully!")
return "ā
The model has been loaded successfully", True
except Exception as e:
logger.error(f"Failed to load medical model {model_name}: {e}")
set_model_loading_state(model_name, "error")
return f"ā Error loading model: {str(e)[:100]}", False
else:
logger.info(f"Medical model {model_name} is already loaded")
return "ā
The model has been loaded successfully", True
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
return f"ā Error: {str(e)[:100]}", False
def load_model_and_update_status(model_name):
"""Load model and update status, return status text and whether model is ready"""
try:
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"ā MedSwin ({model_name}): error loading")
else:
# Use GPU-decorated function to load the model
try:
result = load_model_with_gpu(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
if is_ready:
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
else:
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
except Exception as e:
logger.error(f"Error calling load_model_with_gpu: {e}")
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
# TTS model status (only show if available or if there's an issue)
if SNAC_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("ā
TTS (maya1): loaded and ready")
else:
# TTS available but not loaded - optional feature
pass # Don't show if not loaded, it's optional
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("ā
ASR (Whisper): loaded and ready")
else:
status_lines.append("ā³ ASR (Whisper): will load on first use")
else:
status_lines.append("ā ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
except Exception as e:
return f"ā Error: {str(e)[:100]}", False
def check_model_status(model_name):
"""Check current model status without loading"""
status_lines = []
# Medical model status
if is_model_loaded(model_name):
status_lines.append(f"ā
MedSwin ({model_name}): loaded and ready")
else:
state = get_model_loading_state(model_name)
if state == "loading":
status_lines.append(f"ā³ MedSwin ({model_name}): loading...")
elif state == "error":
status_lines.append(f"ā MedSwin ({model_name}): error loading")
else:
status_lines.append(f"ā ļø MedSwin ({model_name}): not loaded")
# TTS model status (only show if available and loaded)
if SNAC_AVAILABLE:
if config.global_tts_model is not None:
status_lines.append("ā
TTS (maya1): loaded and ready")
# Don't show if TTS library available but model not loaded (optional feature)
# Don't show TTS status if library not available (it's optional)
# ASR (Whisper) model status
if WHISPER_AVAILABLE:
if config.global_whisper_model is not None:
status_lines.append("ā
ASR (Whisper): loaded and ready")
else:
status_lines.append("ā³ ASR (Whisper): will load on first use")
else:
status_lines.append("ā ASR: library not available")
status_text = "\n".join(status_lines)
is_ready = is_model_loaded(model_name)
return status_text, is_ready
# Startup loaders for the medical model only (TTS and Whisper load on demand).
# ZeroGPU best practice:
# 1. Load models to CPU in the global scope (no GPU decorator needed)
# 2. Move models to GPU only inside inference functions (decorated with @spaces.GPU)
# For large models, loading to CPU and then moving to GPU costs extra memory, so a
# direct-to-GPU alternative is also provided further below; the CPU-first loader here
# is the one actually used at startup (see load_startup_and_update_ui).
def load_medical_model_on_startup_cpu():
"""
Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
Model will be moved to GPU during first inference
"""
status_messages = []
try:
# Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load to CPU (no GPU decorator needed)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
logger.info(f"[STARTUP] ā
Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
else:
status_messages.append(f"ā ļø MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"ā MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("ā³ ASR (Whisper): will load on first use")
else:
status_messages.append("ā ASR: library not available")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ā
Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
return f"ā ļø Error loading model: {error_msg[:100]}"
# Alternative: Load directly to GPU (requires GPU decorator)
# @spaces.GPU(max_duration=MAX_DURATION)
def load_medical_model_on_startup_gpu():
"""
Load model directly to GPU on startup (alternative approach)
Uses GPU quota but model is immediately ready for inference
"""
import torch
status_messages = []
try:
# Clear GPU cache at start
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("[STARTUP] Cleared GPU cache before model loading")
# Load only medical model (MedSwin) - TTS and Whisper load on-demand
if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
try:
# Load directly to GPU (within GPU-decorated function)
initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
# Verify model is actually loaded
if is_model_loaded(DEFAULT_MEDICAL_MODEL):
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
logger.info(f"[STARTUP] ā
Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
else:
status_messages.append(f"ā ļø MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
except Exception as e:
status_messages.append(f"ā MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
logger.error(f"[STARTUP] Failed to load medical model: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
else:
status_messages.append(f"ā
MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
# Add ASR status (will load on first use)
if WHISPER_AVAILABLE:
status_messages.append("ā³ ASR (Whisper): will load on first use")
else:
status_messages.append("ā ASR: library not available")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("[STARTUP] Cleared GPU cache after model loading")
# Return status
status_text = "\n".join(status_messages)
logger.info(f"[STARTUP] ā
Model loading complete. Status:\n{status_text}")
return status_text
except Exception as e:
error_msg = str(e)
# Check if it's a ZeroGPU quota/rate limit error
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error:
logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
# Return status message indicating quota error (will be handled by retry logic)
status_messages.append("ā ļø ZeroGPU quota error - will retry")
status_text = "\n".join(status_messages)
# Also add ASR status
if WHISPER_AVAILABLE:
status_text += "\nā³ ASR (Whisper): will load on first use"
return status_text # Return status instead of raising, let wrapper handle retry
logger.error(f"[STARTUP] ā Error in model loading startup: {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return f"ā ļø Startup loading error: {str(e)[:100]}"
# Initialize status on load
def init_model_status():
try:
result = check_model_status(DEFAULT_MEDICAL_MODEL)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "ā ļø Unable to check model status"
except Exception as e:
logger.error(f"Error in init_model_status: {e}")
return f"ā ļø Error: {str(e)[:100]}"
# Update status when model selection changes
def update_model_status_on_change(model_name):
try:
result = check_model_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
return status_text
else:
return "ā ļø Unable to check model status"
except Exception as e:
logger.error(f"Error in update_model_status_on_change: {e}")
return f"ā ļø Error: {str(e)[:100]}"
# Handle model selection change
def on_model_change(model_name):
try:
result = load_model_and_update_status(model_name)
if result and isinstance(result, tuple) and len(result) == 2:
status_text, is_ready = result
submit_enabled = is_ready
return (
status_text,
gr.update(interactive=submit_enabled),
gr.update(interactive=submit_enabled)
)
else:
error_msg = "ā ļø Unable to load model status"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
except Exception as e:
logger.error(f"Error in on_model_change: {e}")
error_msg = f"ā ļø Error: {str(e)[:100]}"
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False)
)
# Update status display periodically or on model status changes
def refresh_model_status(model_name):
return update_model_status_on_change(model_name)
medical_model.change(
fn=on_model_change,
inputs=[medical_model],
outputs=[model_status, submit_button, message_input]
)
# Loader for the Whisper ASR model, called on demand (the @spaces.GPU decorator below is currently disabled)
# @spaces.GPU(max_duration=MAX_DURATION)
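# Note: this helper is not wired to a UI event in this module; it is kept for on-demand ASR loading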
def load_whisper_model_on_demand():
"""Load Whisper ASR model when needed"""
try:
if WHISPER_AVAILABLE and config.global_whisper_model is None:
logger.info("[ASR] Loading Whisper model on-demand...")
initialize_whisper_model()
if config.global_whisper_model is not None:
logger.info("[ASR] ā
Whisper model loaded successfully!")
return "ā
ASR (Whisper): loaded"
else:
logger.warning("[ASR] ā ļø Whisper model failed to load")
return "ā ļø ASR (Whisper): failed to load"
elif config.global_whisper_model is not None:
return "ā
ASR (Whisper): already loaded"
else:
return "ā ASR: library not available"
except Exception as e:
logger.error(f"[ASR] Error loading Whisper model: {e}")
return f"ā ASR: error - {str(e)[:100]}"
# Load medical model on startup and update status
# Use a wrapper with retry logic that also updates the submit button and message input state
def load_startup_and_update_ui():
"""
Load model on startup with retry logic (max 3 attempts) and return status with UI updates
Uses CPU-first approach (ZeroGPU best practice):
- Load model to CPU (no GPU decorator needed, avoids quota issues)
- Model will be moved to GPU during first inference
"""
import time
max_retries = 3
base_delay = 5.0 # Start with 5 seconds delay
for attempt in range(1, max_retries + 1):
try:
logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
# Use CPU-first approach (no GPU decorator, avoids quota issues)
status_text = load_medical_model_on_startup_cpu()
# Check if model is ready and update submit button state
is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
if is_ready:
logger.info(f"[STARTUP] ā
Model loaded successfully on attempt {attempt}")
return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
else:
# Check if status text indicates quota error
if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
"429" in status_text or "runnning out" in status_text.lower() or
"running out" in status_text.lower()):
if attempt < max_retries:
delay = base_delay * attempt
logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Quota exhausted after retries - allow user to proceed, model will load on-demand
status_msg = "ā ļø ZeroGPU quota exhausted.\nā³ Model will load automatically when you send a message.\nš” You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
return status_msg, gr.update(interactive=True), gr.update(interactive=True)
# Model didn't load, but no exception - might be a state issue
logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
if attempt < max_retries:
delay = base_delay * attempt # Linear backoff: 5s, 10s, 15s
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
# Even if model didn't load, allow user to try selecting another model
return status_text + "\n⚠️ Model not loaded. Please select a model from the dropdown.", gr.update(interactive=True), gr.update(interactive=True)
except Exception as e:
error_msg = str(e)
is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
"quota" in error_msg.lower() or "ZeroGPU" in error_msg or
"runnning out" in error_msg.lower() or "running out" in error_msg.lower())
if is_quota_error and attempt < max_retries:
delay = base_delay * attempt # Linear backoff: 5s, 10s, 15s
logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
logger.info(f"[STARTUP] Retrying in {delay} seconds...")
time.sleep(delay)
continue
else:
logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
import traceback
logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
if is_quota_error:
# If quota exhausted, allow user to proceed - model will load on-demand
error_display = "ā ļø ZeroGPU quota exhausted.\nā³ Model will load automatically when you send a message.\nš” You can also select a model from the dropdown."
logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
return error_display, gr.update(interactive=True), gr.update(interactive=True)
else:
error_display = f"ā ļø Startup error: {str(e)[:100]}"
if attempt >= max_retries:
logger.error(f"[STARTUP] Failed after {max_retries} attempts")
return error_display, gr.update(interactive=False), gr.update(interactive=False)
# Should not reach here, but just in case
return "ā ļø Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
demo.load(
fn=load_startup_and_update_ui,
inputs=None,
outputs=[model_status, submit_button, message_input]
)
# Note: The preload-on-focus functionality was removed because:
# 1. Model loading requires GPU access (device_map="auto" needs a GPU under ZeroGPU)
# 2. The startup handler already loads the model (CPU-first, via load_startup_and_update_ui)
# 3. Preloading outside a GPU-decorated context would fail or cause conflicts
# 4. If startup fails, the user can select a model from the dropdown to trigger loading
# Wrap stream_chat - verify the model is loaded before streaming (loading is intentionally not done here to keep stream_chat fast)
def stream_chat_with_model_check(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
):
# Check if model is loaded - if not, show error (don't load here to save stream_chat time)
model_loaded = is_model_loaded(medical_model_name)
if not model_loaded:
loading_state = get_model_loading_state(medical_model_name)
# Debug logging to understand why model check fails
logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
if loading_state == "loading":
error_msg = f"ā³ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
else:
error_msg = f"ā ļø {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
updated_history = history + [{"role": "assistant", "content": error_msg}]
yield updated_history, ""
return
# If request is None, create a mock request for compatibility
if request is None:
class MockRequest:
session_hash = "anonymous"
request = MockRequest()
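# The mock exposes only session_hash, presumably the one gr.Request attribute the pipeline needs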
# Model is loaded, proceed with stream_chat (no model loading here to save time)
# Note: We handle "BodyStreamBuffer was aborted" errors by catching stream disconnections
# and not attempting to yield after the client has disconnected
last_result = None
stream_aborted = False
try:
for result in stream_chat(
message, history, system_prompt, temperature, max_new_tokens,
top_p, top_k, penalty, retriever_k, merge_threshold,
use_rag, medical_model_name, use_web_search,
enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
):
last_result = result
try:
yield result
except (GeneratorExit, StopIteration, RuntimeError) as stream_error:
# Stream was closed/aborted by client - don't try to yield again
error_msg_lower = str(stream_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream was aborted by client, stopping gracefully")
stream_aborted = True
break
raise
except (GeneratorExit, StopIteration) as stream_exit:
# Stream was closed - this is normal, just log and exit
logger.info(f"[UI] Stream closed normally")
stream_aborted = True
return
except Exception as e:
# Handle any errors gracefully
error_str = str(e)
error_msg_lower = error_str.lower()
# Check if this is a stream abort error
is_stream_abort = (
'bodystreambuffer' in error_msg_lower or
('stream' in error_msg_lower and 'abort' in error_msg_lower) or
('connection' in error_msg_lower and 'abort' in error_msg_lower) or
(isinstance(e, (GeneratorExit, StopIteration, RuntimeError)) and 'abort' in error_msg_lower)
)
if is_stream_abort:
logger.info(f"[UI] Stream was aborted (BodyStreamBuffer or similar): {error_str[:100]}")
stream_aborted = True
# If we have a result, it was already yielded, so just return
return
is_gpu_timeout = 'gpu task aborted' in error_msg_lower or 'timeout' in error_msg_lower
logger.error(f"Error in stream_chat_with_model_check: {error_str}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
# Check if we have a valid answer in the last result
has_valid_answer = False
if last_result is not None:
try:
last_history, last_thoughts = last_result
# Find the last assistant message in the history
if last_history and isinstance(last_history, list):
for msg in reversed(last_history):
if isinstance(msg, dict) and msg.get("role") == "assistant":
assistant_content = msg.get("content", "")
# Check if it's a valid answer (not empty, not an error message)
if assistant_content and len(assistant_content.strip()) > 0:
# Not an error message
if not assistant_content.strip().startswith("⚠️") and not assistant_content.strip().startswith("⏳"):
has_valid_answer = True
break
except Exception as parse_error:
logger.debug(f"Error parsing last_result: {parse_error}")
# If stream was aborted, don't try to yield - just return
if stream_aborted:
logger.info(f"[UI] Stream was aborted, not yielding final result")
return
# If we have a valid answer, use it (don't show error message)
if has_valid_answer:
logger.info(f"[UI] Error occurred but final answer already generated, displaying it without error message")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding final result, ignoring")
else:
raise
return
# For GPU timeouts, try to use last result even if it's partial
if is_gpu_timeout and last_result is not None:
logger.info(f"[UI] GPU timeout occurred, using last available result")
try:
yield last_result
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding timeout result, ignoring")
else:
raise
return
# Only show error for non-timeout errors when we have no valid answer
# For GPU timeouts with no result, show empty message (not error)
if is_gpu_timeout:
logger.info(f"[UI] GPU timeout with no result, showing empty assistant message")
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": ""}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding empty message, ignoring")
else:
raise
else:
# For other errors, show minimal error message only if no result
error_display = f"ā ļø An error occurred: {error_str[:200]}"
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_display}]
try:
yield updated_history, ""
except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
error_msg_lower = str(yield_error).lower()
if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
logger.info(f"[UI] Stream aborted while yielding error message, ignoring")
else:
raise
submit_button.click(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
message_input.submit(
fn=stream_chat_with_model_check,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search,
enable_clinical_intake,
disable_agentic_reasoning,
show_thoughts_state
],
outputs=[chatbot, agentic_thoughts_box]
)
return demo
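# Minimal local-run sketch (an assumption: the real entry point, e.g. app.py, may
# configure queue/launch options such as concurrency limits differently):
if __name__ == "__main__":
    ui = create_demo()
    ui.queue().launch()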