Spaces: Running on Zero
Y Phung Nguyen committed
Commit · 2fffb9d
Parent(s): 98c58ec
Upd models loader #7
Browse files
- pipeline.py +2 -1
- supervisor.py +3 -1
- ui.py +61 -33
pipeline.py CHANGED
@@ -26,6 +26,7 @@ from supervisor import (
 )
 
 MAX_CLINICAL_QA_ROUNDS = 5
+MAX_DURATION = 120
 _clinical_intake_sessions = {}
 _clinical_intake_lock = threading.Lock()
 
@@ -343,7 +344,7 @@ def _handle_clinical_answer(session_id: str, answer_text: str):
     return {"type": "question", "prompt": prompt}
 
 
-@spaces.GPU(max_duration=
+@spaces.GPU(max_duration=MAX_DURATION)
 def stream_chat(
     message: str,
     history: list,
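
The pipeline.py change follows a simple pattern: declare the GPU time budget once as a module-level constant and pass it to the GPU decorator. Below is a minimal, self-contained sketch of that pattern; gpu_budget is a hypothetical stand-in for spaces.GPU so the sketch runs anywhere, and the max_duration keyword simply mirrors the diff's own usage.

# Sketch only: one MAX_DURATION constant reused as the decorator's time budget.
# gpu_budget is a hypothetical stand-in for spaces.GPU (assumption, not the real API).
from functools import wraps

MAX_DURATION = 120  # seconds, as added by this commit


def gpu_budget(max_duration: int):
    """Stand-in decorator: records the budget instead of reserving a GPU."""
    def decorate(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"[gpu] running {fn.__name__} with a {max_duration}s budget")
            return fn(*args, **kwargs)
        return wrapper
    return decorate


@gpu_budget(max_duration=MAX_DURATION)
def stream_chat(message: str, history: list) -> str:
    # Stand-in body; the real stream_chat streams a model reply.
    return f"echo: {message}"


print(stream_chat("hello", []))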
supervisor.py CHANGED
@@ -12,6 +12,8 @@ from utils import format_prompt_manually
 MAX_SUBTASKS = 3
 # Maximum number of search strategies
 MAX_SEARCH_STRATEGIES = 3
+# Maximum duration for GPU requests
+MAX_DURATION = 120
 
 try:
     import nest_asyncio
@@ -23,7 +25,7 @@ async def gemini_supervisor_breakdown_async(
     use_rag: bool,
     use_web_search: bool,
     time_elapsed: float,
-    max_duration: int =
+    max_duration: int = MAX_DURATION,
     previous_answer: str | None = None,
 ) -> dict:
     """Gemini Supervisor: Break user query into sub-topics.
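
supervisor.py applies the same idea on the caller side: the module-level constant becomes the keyword default of the breakdown coroutine, so callers only pass max_duration when they need a different budget. A small sketch of this default-from-constant pattern, with hypothetical names, is below.

# Sketch only (names are illustrative, not the real supervisor API).
import asyncio

MAX_DURATION = 120  # seconds, matching the constant added by this commit


async def breakdown(query: str, max_duration: int = MAX_DURATION) -> dict:
    # Placeholder for the real supervisor call, which would honour this budget.
    return {"query": query, "budget_s": max_duration}


async def main() -> None:
    print(await breakdown("summarise the query"))                    # default budget
    print(await breakdown("summarise the query", max_duration=60))   # explicit override


asyncio.run(main())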
ui.py CHANGED
@@ -19,6 +19,7 @@ from models import (
 )
 from logger import logger
 
+MAX_DURATION = 120
 
 def create_demo():
     """Create and return Gradio demo interface"""
@@ -292,7 +293,7 @@ def create_demo():
     )
 
     # GPU-decorated function to load any model (for user selection)
-
+    @spaces.GPU(max_duration=MAX_DURATION)
     def load_model_with_gpu(model_name):
         """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
         try:
@@ -406,7 +407,7 @@ def create_demo():
 
     # GPU-decorated function to load ALL models sequentially on startup
    # This prevents ZeroGPU conflicts from multiple simultaneous GPU requests
-
+    @spaces.GPU(max_duration=MAX_DURATION)
     def load_all_models_on_startup():
         """Load all models sequentially in a single GPU session to avoid ZeroGPU conflicts"""
         import time
@@ -562,7 +563,7 @@ def create_demo():
     )
 
     # GPU-decorated function to load Whisper ASR model on-demand
-    @spaces.GPU(max_duration=
+    @spaces.GPU(max_duration=MAX_DURATION)
     def load_whisper_model_on_demand():
         """Load Whisper ASR model when needed"""
         try:
@@ -610,37 +611,64 @@ def create_demo():
         use_rag, medical_model_name, use_web_search,
         enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
     ):
-
-
-
-
-
-
-
+        import time
+        max_retries = 2
+        base_delay = 2.0
+
+        for attempt in range(max_retries):
+            try:
+                # Check if model is currently loading (don't block if it's already loaded)
+                loading_state = get_model_loading_state(medical_model_name)
+                if loading_state == "loading" and not is_model_loaded(medical_model_name):
+                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
+                    updated_history = history + [{"role": "assistant", "content": error_msg}]
+                    yield updated_history, ""
+                    return
+
+                # If request is None, create a mock request for compatibility
+                if request is None:
+                    class MockRequest:
+                        session_hash = "anonymous"
+                    request = MockRequest()
+
+                # Let stream_chat handle model loading (it's GPU-decorated and can load on-demand)
+                for result in stream_chat(
+                    message, history, system_prompt, temperature, max_new_tokens,
+                    top_p, top_k, penalty, retriever_k, merge_threshold,
+                    use_rag, medical_model_name, use_web_search,
+                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
+                ):
+                    yield result
+                # If we get here, stream_chat completed successfully
                 return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            except Exception as e:
+                error_msg_lower = str(e).lower()
+                is_gpu_error = 'gpu task aborted' in error_msg_lower or 'gpu' in error_msg_lower or 'zerogpu' in error_msg_lower
+
+                if is_gpu_error and attempt < max_retries - 1:
+                    delay = base_delay * (2 ** attempt)  # Exponential backoff: 2s, 4s
+                    logger.warning(f"[STREAM_CHAT] GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
+                    # Yield a message to user about retry
+                    retry_msg = f"⏳ GPU task was interrupted. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})"
+                    updated_history = history + [{"role": "assistant", "content": retry_msg}]
+                    yield updated_history, ""
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Final error handling
+                    logger.error(f"[STREAM_CHAT] Error in stream_chat_with_model_check: {e}")
+                    import traceback
+                    logger.error(f"[STREAM_CHAT] Full traceback: {traceback.format_exc()}")
+
+                    if is_gpu_error:
+                        error_msg = f"⚠️ GPU task was aborted. This can happen if:\n- The request took too long\n- Multiple GPU requests conflicted\n- GPU quota was exceeded\n\nPlease try again or select a different model."
+                    else:
+                        error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
+
+                    updated_history = history + [{"role": "assistant", "content": error_msg}]
+                    yield updated_history, ""
+                    return
 
     submit_button.click(
         fn=stream_chat_with_model_check,
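
The bulk of the ui.py change wraps the streaming call in a bounded retry loop with exponential backoff, surfacing a status message to the user between attempts. The sketch below shows that shape in isolation with hypothetical names and a simulated failure; the real handler retries only GPU-related errors and yields Gradio (history, textbox) updates rather than plain strings.

# Sketch only: bounded retry with exponential backoff around a streaming generator.
import time
from typing import Iterator

MAX_RETRIES = 2
BASE_DELAY = 2.0  # seconds; doubles each attempt (2s, 4s, ...)
_calls = {"n": 0}


def flaky_stream(message: str) -> Iterator[str]:
    # Stand-in for the GPU-decorated stream_chat generator; fails on the first call.
    _calls["n"] += 1
    if _calls["n"] == 1:
        raise RuntimeError("GPU task aborted")
    yield f"partial answer to: {message}"
    yield "final answer."


def stream_with_retry(message: str) -> Iterator[str]:
    for attempt in range(MAX_RETRIES):
        try:
            yield from flaky_stream(message)
            return  # completed without error
        except Exception as exc:
            if "gpu" in str(exc).lower() and attempt < MAX_RETRIES - 1:
                delay = BASE_DELAY * (2 ** attempt)
                yield f"GPU task interrupted, retrying in {delay:.0f}s ..."
                time.sleep(delay)
                continue
            yield f"Error: {exc}"
            return


for chunk in stream_with_retry("hello"):
    print(chunk)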