Y Phung Nguyen committed on
Commit 2fffb9d · 1 Parent(s): 98c58ec

Upd models loader #7

Files changed (3)
  1. pipeline.py +2 -1
  2. supervisor.py +3 -1
  3. ui.py +61 -33
pipeline.py CHANGED
@@ -26,6 +26,7 @@ from supervisor import (
 )
 
 MAX_CLINICAL_QA_ROUNDS = 5
+MAX_DURATION = 120
 _clinical_intake_sessions = {}
 _clinical_intake_lock = threading.Lock()
 
@@ -343,7 +344,7 @@ def _handle_clinical_answer(session_id: str, answer_text: str):
     return {"type": "question", "prompt": prompt}
 
 
-@spaces.GPU(max_duration=120)
+@spaces.GPU(max_duration=MAX_DURATION)
 def stream_chat(
     message: str,
     history: list,
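The pipeline.py change hoists the hard-coded 120-second limit into a module-level MAX_DURATION constant that the @spaces.GPU decorator references. A minimal sketch of why the constant has to be defined above the decorated function, using a hypothetical gpu_task decorator in place of spaces.GPU (decorator arguments are evaluated once, when the module is imported):

# Hypothetical stand-in for spaces.GPU; the keyword name is only illustrative.
def gpu_task(max_duration):
    def wrap(fn):
        fn.max_duration = max_duration  # value captured at decoration (import) time
        return fn
    return wrap

MAX_DURATION = 120  # must already exist when the decorator line below runs

@gpu_task(max_duration=MAX_DURATION)
def stream_chat(message, history):
    yield message

print(stream_chat.max_duration)  # 120, even if MAX_DURATION is rebound afterwards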
supervisor.py CHANGED
@@ -12,6 +12,8 @@ from utils import format_prompt_manually
 MAX_SUBTASKS = 3
 # Maximum number of search strategies
 MAX_SEARCH_STRATEGIES = 3
+# Maximum duration for GPU requests
+MAX_DURATION = 120
 
 try:
     import nest_asyncio
@@ -23,7 +25,7 @@ async def gemini_supervisor_breakdown_async(
     use_rag: bool,
     use_web_search: bool,
     time_elapsed: float,
-    max_duration: int = 120,
+    max_duration: int = MAX_DURATION,
     previous_answer: str | None = None,
 ) -> dict:
     """Gemini Supervisor: Break user query into sub-topics.
ui.py CHANGED
@@ -19,6 +19,7 @@ from models import (
 )
 from logger import logger
 
+MAX_DURATION = 120
 
 def create_demo():
     """Create and return Gradio demo interface"""
@@ -292,7 +293,7 @@ def create_demo():
     )
 
     # GPU-decorated function to load any model (for user selection)
-    # @spaces.GPU(max_duration=120)
+    @spaces.GPU(max_duration=MAX_DURATION)
     def load_model_with_gpu(model_name):
         """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
         try:
@@ -406,7 +407,7 @@ def create_demo():
 
     # GPU-decorated function to load ALL models sequentially on startup
    # This prevents ZeroGPU conflicts from multiple simultaneous GPU requests
-    # @spaces.GPU(max_duration=120)
+    @spaces.GPU(max_duration=MAX_DURATION)
    def load_all_models_on_startup():
        """Load all models sequentially in a single GPU session to avoid ZeroGPU conflicts"""
        import time
@@ -562,7 +563,7 @@ def create_demo():
     )
 
     # GPU-decorated function to load Whisper ASR model on-demand
-    @spaces.GPU(max_duration=120)
+    @spaces.GPU(max_duration=MAX_DURATION)
     def load_whisper_model_on_demand():
         """Load Whisper ASR model when needed"""
         try:
@@ -610,37 +611,64 @@ def create_demo():
         use_rag, medical_model_name, use_web_search,
         enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
     ):
-        try:
-            # Check if model is currently loading (don't block if it's already loaded)
-            loading_state = get_model_loading_state(medical_model_name)
-            if loading_state == "loading" and not is_model_loaded(medical_model_name):
-                error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
-                updated_history = history + [{"role": "assistant", "content": error_msg}]
-                yield updated_history, ""
+        import time
+        max_retries = 2
+        base_delay = 2.0
+
+        for attempt in range(max_retries):
+            try:
+                # Check if model is currently loading (don't block if it's already loaded)
+                loading_state = get_model_loading_state(medical_model_name)
+                if loading_state == "loading" and not is_model_loaded(medical_model_name):
+                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
+                    updated_history = history + [{"role": "assistant", "content": error_msg}]
+                    yield updated_history, ""
+                    return
+
+                # If request is None, create a mock request for compatibility
+                if request is None:
+                    class MockRequest:
+                        session_hash = "anonymous"
+                    request = MockRequest()
+
+                # Let stream_chat handle model loading (it's GPU-decorated and can load on-demand)
+                for result in stream_chat(
+                    message, history, system_prompt, temperature, max_new_tokens,
+                    top_p, top_k, penalty, retriever_k, merge_threshold,
+                    use_rag, medical_model_name, use_web_search,
+                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
+                ):
+                    yield result
+                # If we get here, stream_chat completed successfully
                 return
-
-            # If request is None, create a mock request for compatibility
-            if request is None:
-                class MockRequest:
-                    session_hash = "anonymous"
-                request = MockRequest()
-
-            # Let stream_chat handle model loading (it's GPU-decorated and can load on-demand)
-            for result in stream_chat(
-                message, history, system_prompt, temperature, max_new_tokens,
-                top_p, top_k, penalty, retriever_k, merge_threshold,
-                use_rag, medical_model_name, use_web_search,
-                enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
-            ):
-                yield result
-        except Exception as e:
-            # Handle any errors gracefully
-            logger.error(f"Error in stream_chat_with_model_check: {e}")
-            import traceback
-            logger.debug(f"Full traceback: {traceback.format_exc()}")
-            error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
-            updated_history = history + [{"role": "assistant", "content": error_msg}]
-            yield updated_history, ""
+
+            except Exception as e:
+                error_msg_lower = str(e).lower()
+                is_gpu_error = 'gpu task aborted' in error_msg_lower or 'gpu' in error_msg_lower or 'zerogpu' in error_msg_lower
+
+                if is_gpu_error and attempt < max_retries - 1:
+                    delay = base_delay * (2 ** attempt)  # Exponential backoff: 2s, 4s
+                    logger.warning(f"[STREAM_CHAT] GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
+                    # Yield a message to user about retry
+                    retry_msg = f"⏳ GPU task was interrupted. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})"
+                    updated_history = history + [{"role": "assistant", "content": retry_msg}]
+                    yield updated_history, ""
+                    time.sleep(delay)
+                    continue
+                else:
+                    # Final error handling
+                    logger.error(f"[STREAM_CHAT] Error in stream_chat_with_model_check: {e}")
+                    import traceback
+                    logger.error(f"[STREAM_CHAT] Full traceback: {traceback.format_exc()}")
+
+                    if is_gpu_error:
+                        error_msg = f"⚠️ GPU task was aborted. This can happen if:\n- The request took too long\n- Multiple GPU requests conflicted\n- GPU quota was exceeded\n\nPlease try again or select a different model."
+                    else:
+                        error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
+
+                    updated_history = history + [{"role": "assistant", "content": error_msg}]
+                    yield updated_history, ""
+                    return
 
     submit_button.click(
         fn=stream_chat_with_model_check,
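The bulk of the ui.py change wraps the streaming call in a retry loop with exponential backoff, yielding a status message between attempts. A stripped-down sketch of that pattern follows; `stream_with_retry`, `flaky`, and the GPU-error test are placeholders for stream_chat_with_model_check, stream_chat, and the is_gpu_error check in the diff:

import time

def stream_with_retry(run_stream, max_retries=2, base_delay=2.0):
    """Re-run a streaming generator on transient errors, backing off 2s, 4s, ..."""
    for attempt in range(max_retries):
        try:
            for chunk in run_stream():
                yield chunk
            return  # stream finished without raising
        except Exception as exc:
            transient = "gpu" in str(exc).lower()  # stands in for the real GPU-error check
            if transient and attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)  # exponential backoff
                yield f"⏳ interrupted, retrying in {delay}s (attempt {attempt + 1}/{max_retries})"
                time.sleep(delay)
            else:
                yield f"⚠️ failed after {attempt + 1} attempt(s): {exc}"
                return

# Usage: the first attempt aborts, the retry streams to completion.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("GPU task aborted")
    yield "partial answer"
    yield "final answer"

for out in stream_with_retry(flaky, base_delay=0.1):
    print(out)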