"""Model initialization and management"""
import os
import torch
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from logger import logger
import config
import spaces
try:
from snac import SNAC
SNAC_AVAILABLE = True
except ImportError:
SNAC_AVAILABLE = False
SNAC = None
# For backward compatibility, also check for the TTS library (Maya1 is used directly instead)
try:
from TTS.api import TTS
TTS_AVAILABLE = True
except ImportError:
TTS_AVAILABLE = False
TTS = None
try:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
try:
import torchaudio
except ImportError:
torchaudio = None
WHISPER_AVAILABLE = True
except ImportError:
WHISPER_AVAILABLE = False
WhisperProcessor = None
WhisperForConditionalGeneration = None
torchaudio = None
# Model loading state tracking
_model_loading_states = {}
_model_loading_lock = threading.Lock()
def set_model_loading_state(model_name: str, state: str):
"""
Set model loading state: 'loading', 'loaded', 'error'
Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required.
"""
with _model_loading_lock:
_model_loading_states[model_name] = state
logger.debug(f"Model {model_name} state set to: {state}")
def get_model_loading_state(model_name: str) -> str:
"""
Get model loading state: 'loading', 'loaded', 'error', or 'unknown'
Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required.
"""
with _model_loading_lock:
return _model_loading_states.get(model_name, "unknown")
def is_model_loaded(model_name: str) -> bool:
"""Check if model is loaded and ready"""
with _model_loading_lock:
return (model_name in config.global_medical_models and
config.global_medical_models[model_name] is not None and
_model_loading_states.get(model_name) == "loaded")
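# Illustrative sketch (not part of the original module) of how callers can poll the
# state helpers above before dispatching a request. The function name and timeout are
# hypothetical placeholders.
def _example_wait_until_ready(model_name: str, timeout_s: float = 120.0) -> bool:
    import time  # local import to keep this sketch self-contained
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if is_model_loaded(model_name):
            return True
        if get_model_loading_state(model_name) == "error":
            return False
        time.sleep(1.0)
    return False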
def initialize_medical_model(model_name: str, load_to_gpu: bool = True):
"""
Initialize a medical model (MedSwin), downloading it on demand.
According to ZeroGPU best practices:
- If load_to_gpu=True: load directly to the GPU using device_map="auto" (must be called within a @spaces.GPU-decorated function)
- If load_to_gpu=False: load to the CPU first, then move to the GPU in the inference function
Args:
model_name: Name of the model to load
load_to_gpu: If True, load directly to the GPU. If False, load to the CPU (per ZeroGPU best practices)
"""
if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
set_model_loading_state(model_name, "loading")
logger.info(f"Initializing medical model: {model_name}... (load_to_gpu={load_to_gpu})")
try:
model_path = config.MEDSWIN_MODELS[model_name]
tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
if load_to_gpu:
# Load directly to GPU (must be within @spaces.GPU decorated function)
# Clear GPU cache before loading to prevent memory issues
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache before model loading")
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto", # Automatically places model on GPU
trust_remote_code=True,
token=config.HF_TOKEN,
torch_dtype=torch.float16
)
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache after model loading")
else:
# Load to CPU first (ZeroGPU best practice - no GPU decorator needed)
logger.info(f"Loading {model_name} to CPU (will move to GPU during inference)...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="cpu", # Load to CPU
trust_remote_code=True,
token=config.HF_TOKEN,
torch_dtype=torch.float16
)
logger.info(f"Model {model_name} loaded to CPU successfully")
# Set models in config BEFORE setting state to "loaded"
config.global_medical_models[model_name] = model
config.global_medical_tokenizers[model_name] = tokenizer
# Set state to "loaded" AFTER models are stored
set_model_loading_state(model_name, "loaded")
logger.info(f"Medical model {model_name} initialized successfully")
# Verify the state was set correctly
if not is_model_loaded(model_name):
logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
except Exception as e:
set_model_loading_state(model_name, "error")
logger.error(f"Failed to initialize medical model {model_name}: {e}")
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise
else:
# Model already loaded, ensure state is set
if get_model_loading_state(model_name) != "loaded":
logger.info(f"Model {model_name} exists in config but state not set to 'loaded'. Setting state now.")
set_model_loading_state(model_name, "loaded")
return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
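# Illustrative sketch (not part of the original module) of the two loading modes the
# docstring above describes. The function names are hypothetical; only
# initialize_medical_model, config.DEFAULT_MEDICAL_MODEL, and spaces.GPU come from
# this module's own imports.
def _example_startup_preload() -> None:
    # Runs at app startup: downloads weights and keeps them on CPU, so no GPU quota is used.
    initialize_medical_model(config.DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
@spaces.GPU
def _example_load_directly_on_gpu() -> None:
    # Runs inside a GPU-allocated context, so device_map="auto" may place weights on CUDA.
    initialize_medical_model(config.DEFAULT_MEDICAL_MODEL, load_to_gpu=True)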
def move_model_to_gpu(model_name: str):
"""
Move a model from CPU to GPU (for ZeroGPU best practices)
Must be called within a @spaces.GPU-decorated function
According to ZeroGPU best practices:
- Models should be loaded to CPU first (no GPU quota used)
- Models are moved to GPU only during inference (within @spaces.GPU decorated function)
For models loaded with device_map="cpu", we reload with device_map="auto" to avoid
meta tensor issues when moving to GPU.
"""
if model_name not in config.global_medical_models:
raise ValueError(f"Model {model_name} not found in config")
model = config.global_medical_models[model_name]
if model is None:
raise ValueError(f"Model {model_name} is None")
# Check if model is already on GPU
try:
# For models with device_map, check the actual device
if hasattr(model, 'device'):
device_str = str(model.device)
if 'cuda' in device_str.lower():
logger.debug(f"Model {model_name} is already on GPU ({device_str})")
return model
# Check device_map if available
if hasattr(model, 'hf_device_map'):
device_map = model.hf_device_map
if isinstance(device_map, dict):
# Check if any device is GPU
if any('cuda' in str(dev).lower() for dev in device_map.values()):
logger.debug(f"Model {model_name} is already on GPU (device_map)")
return model
except Exception as e:
logger.debug(f"Could not check model device: {e}")
# For models loaded with device_map="cpu", we need to reload with device_map="auto"
# because models with meta tensors cannot be moved with .to()
logger.info(f"Moving model {model_name} from CPU to GPU...")
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Get model path for reloading
if model_name not in config.MEDSWIN_MODELS:
raise ValueError(f"Model path for {model_name} not found in config.MEDSWIN_MODELS")
model_path = config.MEDSWIN_MODELS[model_name]
try:
# Reload model with device_map="auto" to place it on GPU
# This avoids meta tensor issues when moving from CPU to GPU
logger.info(f"Reloading model {model_name} with device_map='auto' for GPU placement...")
# Delete the old model to free memory
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Reload with GPU device_map
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto", # Automatically places model on GPU
trust_remote_code=True,
token=config.HF_TOKEN,
torch_dtype=torch.float16
)
config.global_medical_models[model_name] = model
logger.info(f"Model {model_name} reloaded to GPU successfully")
except Exception as e:
logger.error(f"Failed to reload model {model_name} to GPU: {e}")
# Try fallback with accelerate dispatch if reload fails
try:
logger.info(f"Trying accelerate dispatch as fallback...")
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
# The `del model` above only removed the local reference; reuse the copy still stored in config, or reload to CPU if it is missing
if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
logger.info(f"Reloading model {model_name} to CPU for accelerate dispatch...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="cpu",
trust_remote_code=True,
token=config.HF_TOKEN,
torch_dtype=torch.float16
)
config.global_medical_models[model_name] = model
else:
model = config.global_medical_models[model_name]
# Get device map for GPU
max_memory = get_balanced_memory(model, max_memory={0: "20GiB"})
device_map = infer_auto_device_map(model, max_memory=max_memory)
model = dispatch_model(model, device_map=device_map)
config.global_medical_models[model_name] = model
logger.info(f"Model {model_name} moved to GPU successfully using accelerate dispatch")
except Exception as e2:
logger.error(f"Failed to move model {model_name} to GPU with all methods: {e2}")
raise
if torch.cuda.is_available():
torch.cuda.empty_cache()
return model
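# Illustrative sketch (not part of the original module) of an inference entry point that
# follows the CPU-first pattern: fetch or load the model on CPU, then move it to the GPU
# inside the @spaces.GPU context before generating. The function name and generation
# settings are hypothetical.
@spaces.GPU
def _example_generate(model_name: str, prompt: str, max_new_tokens: int = 256) -> str:
    model, tokenizer = initialize_medical_model(model_name, load_to_gpu=False)
    model = move_model_to_gpu(model_name)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)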
def initialize_tts_model():
"""Initialize Maya1 TTS model for text-to-speech using transformers and SNAC"""
if not SNAC_AVAILABLE:
logger.warning("SNAC library not installed. Maya1 TTS features will be disabled.")
logger.warning("Install with: pip install snac")
return None
if config.global_tts_model is None:
try:
# Clear GPU cache before loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache before TTS model loading")
logger.info("Initializing Maya1 TTS model with Transformers...")
# Load Maya1 model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
config.TTS_MODEL,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
token=config.HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(
config.TTS_MODEL,
trust_remote_code=True,
token=config.HF_TOKEN
)
logger.info("Loading SNAC decoder...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
if torch.cuda.is_available():
snac_model = snac_model.to("cuda")
# Store as a dictionary with model, tokenizer, and snac_model
config.global_tts_model = {
"model": model,
"tokenizer": tokenizer,
"snac_model": snac_model
}
logger.info("Maya1 TTS model initialized successfully")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache after TTS model loading")
except Exception as e:
logger.warning(f"Maya1 TTS model initialization failed: {e}")
import traceback
logger.warning(f"TTS initialization traceback: {traceback.format_exc()}")
logger.warning("TTS features will be disabled. Install dependencies: pip install snac transformers")
config.global_tts_model = None
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return config.global_tts_model
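# Illustrative sketch (not part of the original module) showing that callers are expected
# to unpack the bundle returned above rather than receive a single model object. The
# Maya1 prompt format and SNAC decoding happen elsewhere and are not shown here.
def _example_get_tts_components():
    tts = initialize_tts_model()
    if tts is None:
        return None  # SNAC missing or initialization failed; TTS is disabled
    return tts["model"], tts["tokenizer"], tts["snac_model"]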
def initialize_whisper_model():
"""Initialize Whisper model for speech-to-text (ASR) from Hugging Face"""
if not WHISPER_AVAILABLE:
logger.warning("Whisper transformers not installed. ASR features will be disabled.")
return None
if config.global_whisper_model is None:
try:
# Clear GPU cache before loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache before Whisper model loading")
logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
model_id = "openai/whisper-large-v3-turbo"
processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.float16,
token=config.HF_TOKEN
)
# Store both processor and model
config.global_whisper_model = {"processor": processor, "model": model}
logger.info(f"Whisper model ({model_id}) initialized successfully")
# Clear cache after loading
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared GPU cache after Whisper model loading")
except Exception as e:
logger.warning(f"Whisper model initialization failed: {e}")
logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
config.global_whisper_model = None
# Clear cache on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
return config.global_whisper_model
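# Illustrative transcription sketch (not part of the original module) using the
# processor/model bundle above. It assumes 16 kHz mono audio, Whisper's expected
# input rate; resampling is out of scope here.
@spaces.GPU
def _example_transcribe(audio_array, sampling_rate: int = 16000) -> str:
    bundle = initialize_whisper_model()
    if bundle is None:
        return ""
    processor, model = bundle["processor"], bundle["model"]
    inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features.to(model.device, dtype=model.dtype)
    generated_ids = model.generate(input_features)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]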
def get_or_create_embed_model():
"""Reuse embedding model to avoid reloading weights each request"""
if config.global_embed_model is None:
logger.info("Initializing shared embedding model for RAG retrieval...")
config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
return config.global_embed_model
def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
"""Get LLM for RAG indexing (uses medical model)"""
medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
return HuggingFaceLLM(
context_window=4096,
max_new_tokens=max_new_tokens,
tokenizer=medical_tokenizer,
model=medical_model_obj,
generate_kwargs={
"do_sample": True,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p
}
)
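# Illustrative sketch (not part of the original module) of wiring these factories into a
# LlamaIndex pipeline. It assumes the llama_index.core Settings API (matching the 0.10+
# import paths used at the top of this file); the document directory is a placeholder.
def _example_build_rag_index(doc_dir: str = "./docs"):
    from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
    Settings.llm = get_llm_for_rag()
    Settings.embed_model = get_or_create_embed_model()
    documents = SimpleDirectoryReader(doc_dir).load_data()
    return VectorStoreIndex.from_documents(documents)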