| """Model initialization and management""" | |
| import os | |
| import torch | |
| import threading | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from llama_index.llms.huggingface import HuggingFaceLLM | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from logger import logger | |
| import config | |
| import spaces | |
| try: | |
| from snac import SNAC | |
| SNAC_AVAILABLE = True | |
| except ImportError: | |
| SNAC_AVAILABLE = False | |
| SNAC = None | |
| # For backward compatibility, check TTS library too (but we use Maya1 directly) | |
| try: | |
| from TTS.api import TTS | |
| TTS_AVAILABLE = True | |
| except ImportError: | |
| TTS_AVAILABLE = False | |
| TTS = None | |
| try: | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| try: | |
| import torchaudio | |
| except ImportError: | |
| torchaudio = None | |
| WHISPER_AVAILABLE = True | |
| except ImportError: | |
| WHISPER_AVAILABLE = False | |
| WhisperProcessor = None | |
| WhisperForConditionalGeneration = None | |
| torchaudio = None | |
| # Model loading state tracking | |
| _model_loading_states = {} | |
| _model_loading_lock = threading.Lock() | |

def set_model_loading_state(model_name: str, state: str):
    """
    Set model loading state: 'loading', 'loaded', or 'error'.
    Note: No GPU decorator needed - this just sets a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        _model_loading_states[model_name] = state
        logger.debug(f"Model {model_name} state set to: {state}")


def get_model_loading_state(model_name: str) -> str:
    """
    Get model loading state: 'loading', 'loaded', 'error', or 'unknown'.
    Note: No GPU decorator needed - this just reads a dictionary value, no GPU access required.
    """
    with _model_loading_lock:
        return _model_loading_states.get(model_name, "unknown")


def is_model_loaded(model_name: str) -> bool:
    """Check if a model is loaded and ready."""
    with _model_loading_lock:
        return (model_name in config.global_medical_models and
                config.global_medical_models[model_name] is not None and
                _model_loading_states.get(model_name) == "loaded")
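
# A minimal usage sketch (assumed, not part of the original module): because the helpers
# above only read/write a plain dict under a lock, a UI callback can poll readiness
# without any GPU decorator or quota cost.
def model_status_message(model_name: str) -> str:
    """Return a short human-readable status string for the given model (illustrative only)."""
    if is_model_loaded(model_name):
        return f"{model_name}: ready"
    return f"{model_name}: {get_model_loading_state(model_name)}"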

def initialize_medical_model(model_name: str, load_to_gpu: bool = True):
    """
    Initialize a medical model (MedSwin) - downloaded on demand.

    According to ZeroGPU best practices:
    - If load_to_gpu=True: load directly to GPU using device_map="auto" (must be called within a @spaces.GPU decorated function)
    - If load_to_gpu=False: load to CPU first, then move to GPU in the inference function

    Args:
        model_name: Name of the model to load
        load_to_gpu: If True, load directly to GPU. If False, load to CPU (for ZeroGPU best practices)
    """
    if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
        set_model_loading_state(model_name, "loading")
        logger.info(f"Initializing medical model: {model_name}... (load_to_gpu={load_to_gpu})")
        try:
            model_path = config.MEDSWIN_MODELS[model_name]
            tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
            if load_to_gpu:
                # Load directly to GPU (must be within a @spaces.GPU decorated function)
                # Clear GPU cache before loading to prevent memory issues
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("Cleared GPU cache before model loading")
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto",  # Automatically places the model on GPU
                    trust_remote_code=True,
                    token=config.HF_TOKEN,
                    torch_dtype=torch.float16
                )
                # Clear cache after loading
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("Cleared GPU cache after model loading")
            else:
                # Load to CPU first (ZeroGPU best practice - no GPU decorator needed)
                logger.info(f"Loading {model_name} to CPU (will move to GPU during inference)...")
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="cpu",  # Load to CPU
                    trust_remote_code=True,
                    token=config.HF_TOKEN,
                    torch_dtype=torch.float16
                )
                logger.info(f"Model {model_name} loaded to CPU successfully")
            # Store the model and tokenizer in config BEFORE setting the state to "loaded"
            config.global_medical_models[model_name] = model
            config.global_medical_tokenizers[model_name] = tokenizer
            # Set the state to "loaded" AFTER the objects are stored
            set_model_loading_state(model_name, "loaded")
            logger.info(f"Medical model {model_name} initialized successfully")
            # Verify the state was set correctly
            if not is_model_loaded(model_name):
                logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
        except Exception as e:
            set_model_loading_state(model_name, "error")
            logger.error(f"Failed to initialize medical model {model_name}: {e}")
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise
    else:
        # Model already loaded; make sure the state reflects that
        if get_model_loading_state(model_name) != "loaded":
            logger.info(f"Model {model_name} exists in config but its state is not 'loaded'. Setting it now.")
            set_model_loading_state(model_name, "loaded")
    return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
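
# A minimal ZeroGPU startup sketch (assumed usage, not part of the original module):
# preload the default model to CPU when the Space boots so no GPU quota is consumed;
# the GPU-side half of the pattern is sketched after move_model_to_gpu() below.
def preload_default_medical_model():
    # load_to_gpu=False keeps the weights on CPU until an inference request arrives
    initialize_medical_model(config.DEFAULT_MEDICAL_MODEL, load_to_gpu=False)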

def move_model_to_gpu(model_name: str):
    """
    Move a model from CPU to GPU (for ZeroGPU best practices).
    Must be called within a @spaces.GPU decorated function.

    According to ZeroGPU best practices:
    - Models should be loaded to CPU first (no GPU quota used)
    - Models are moved to GPU only during inference (within a @spaces.GPU decorated function)

    For models loaded with device_map="cpu", we reload with device_map="auto" to avoid
    meta tensor issues when moving to GPU.
    """
    if model_name not in config.global_medical_models:
        raise ValueError(f"Model {model_name} not found in config")
    model = config.global_medical_models[model_name]
    if model is None:
        raise ValueError(f"Model {model_name} is None")
    # Check if the model is already on GPU
    try:
        # For models with a device_map, check the actual device
        if hasattr(model, 'device'):
            device_str = str(model.device)
            if 'cuda' in device_str.lower():
                logger.debug(f"Model {model_name} is already on GPU ({device_str})")
                return model
        # Check device_map if available
        if hasattr(model, 'hf_device_map'):
            device_map = model.hf_device_map
            if isinstance(device_map, dict):
                # Check if any device is a GPU
                if any('cuda' in str(dev).lower() for dev in device_map.values()):
                    logger.debug(f"Model {model_name} is already on GPU (device_map)")
                    return model
    except Exception as e:
        logger.debug(f"Could not check model device: {e}")
    # For models loaded with device_map="cpu", we need to reload with device_map="auto"
    # because models with meta tensors cannot be moved with .to()
    logger.info(f"Moving model {model_name} from CPU to GPU...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Get the model path for reloading
    if model_name not in config.MEDSWIN_MODELS:
        raise ValueError(f"Model path for {model_name} not found in config.MEDSWIN_MODELS")
    model_path = config.MEDSWIN_MODELS[model_name]
    try:
        # Reload the model with device_map="auto" to place it on GPU
        # This avoids meta tensor issues when moving from CPU to GPU
        logger.info(f"Reloading model {model_name} with device_map='auto' for GPU placement...")
        # Delete the old model to free memory
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Reload with GPU device_map
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",  # Automatically places the model on GPU
            trust_remote_code=True,
            token=config.HF_TOKEN,
            torch_dtype=torch.float16
        )
        config.global_medical_models[model_name] = model
        logger.info(f"Model {model_name} reloaded to GPU successfully")
    except Exception as e:
        logger.error(f"Failed to reload model {model_name} to GPU: {e}")
        # Try accelerate dispatch as a fallback if the reload fails
        try:
            logger.info("Trying accelerate dispatch as fallback...")
            from accelerate import dispatch_model
            from accelerate.utils import get_balanced_memory, infer_auto_device_map
            # Reload the model first (in case deletion happened)
            if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
                logger.info(f"Reloading model {model_name} to CPU for accelerate dispatch...")
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="cpu",
                    trust_remote_code=True,
                    token=config.HF_TOKEN,
                    torch_dtype=torch.float16
                )
                config.global_medical_models[model_name] = model
            else:
                model = config.global_medical_models[model_name]
            # Compute a device map for the GPU
            max_memory = get_balanced_memory(model, max_memory={0: "20GiB"})
            device_map = infer_auto_device_map(model, max_memory=max_memory)
            model = dispatch_model(model, device_map=device_map)
            config.global_medical_models[model_name] = model
            logger.info(f"Model {model_name} moved to GPU successfully using accelerate dispatch")
        except Exception as e2:
            logger.error(f"Failed to move model {model_name} to GPU with all methods: {e2}")
            raise
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model
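
# A minimal inference sketch (assumed usage, not part of the original module): the
# @spaces.GPU decorator requests a GPU only for the duration of this call, the model is
# moved/reloaded onto it via move_model_to_gpu(), and the prompt handling is a placeholder.
@spaces.GPU
def generate_answer(model_name: str, prompt: str, max_new_tokens: int = 256) -> str:
    model, tokenizer = initialize_medical_model(model_name, load_to_gpu=False)
    model = move_model_to_gpu(model_name)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
    # Strip the prompt tokens and return only the newly generated text
    return tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)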

def initialize_tts_model():
    """Initialize the Maya1 TTS model for text-to-speech using transformers and SNAC."""
    if not SNAC_AVAILABLE:
        logger.warning("SNAC library not installed. Maya1 TTS features will be disabled.")
        logger.warning("Install with: pip install snac")
        return None
    if config.global_tts_model is None:
        try:
            # Clear GPU cache before loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache before TTS model loading")
            logger.info("Initializing Maya1 TTS model with Transformers...")
            # Load the Maya1 model and tokenizer
            model = AutoModelForCausalLM.from_pretrained(
                config.TTS_MODEL,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                token=config.HF_TOKEN
            )
            tokenizer = AutoTokenizer.from_pretrained(
                config.TTS_MODEL,
                trust_remote_code=True,
                token=config.HF_TOKEN
            )
            logger.info("Loading SNAC decoder...")
            snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
            if torch.cuda.is_available():
                snac_model = snac_model.to("cuda")
            # Store as a dictionary with model, tokenizer, and snac_model
            config.global_tts_model = {
                "model": model,
                "tokenizer": tokenizer,
                "snac_model": snac_model
            }
            logger.info("Maya1 TTS model initialized successfully")
            # Clear cache after loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache after TTS model loading")
        except Exception as e:
            logger.warning(f"Maya1 TTS model initialization failed: {e}")
            import traceback
            logger.warning(f"TTS initialization traceback: {traceback.format_exc()}")
            logger.warning("TTS features will be disabled. Install dependencies: pip install snac transformers")
            config.global_tts_model = None
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return config.global_tts_model
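
# A minimal decoding sketch (assumed usage, not part of the original module): turning SNAC
# code streams into a 24 kHz waveform with the snac_model stored above. How Maya1 token IDs
# are unpacked into the three SNAC code tensors is model-specific and not shown here.
def decode_snac_codes_to_audio(codes):
    """codes: list of three LongTensors of SNAC codes (coarse to fine), on the SNAC model's device."""
    tts = initialize_tts_model()
    if tts is None:
        raise RuntimeError("Maya1 TTS is unavailable (SNAC not installed or initialization failed)")
    snac_model = tts["snac_model"]
    with torch.inference_mode():
        audio = snac_model.decode(codes)  # returns a (1, 1, num_samples) waveform tensor
    return audio.squeeze().cpu().numpy()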

def initialize_whisper_model():
    """Initialize the Whisper model for speech-to-text (ASR) from Hugging Face."""
    if not WHISPER_AVAILABLE:
        logger.warning("Whisper transformers not installed. ASR features will be disabled.")
        return None
    if config.global_whisper_model is None:
        try:
            # Clear GPU cache before loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache before Whisper model loading")
            logger.info("Initializing Whisper model (openai/whisper-large-v3-turbo) from Hugging Face...")
            model_id = "openai/whisper-large-v3-turbo"
            processor = WhisperProcessor.from_pretrained(model_id, token=config.HF_TOKEN)
            model = WhisperForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                token=config.HF_TOKEN
            )
            # Store both the processor and the model
            config.global_whisper_model = {"processor": processor, "model": model}
            logger.info(f"Whisper model ({model_id}) initialized successfully")
            # Clear cache after loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared GPU cache after Whisper model loading")
        except Exception as e:
            logger.warning(f"Whisper model initialization failed: {e}")
            logger.warning("ASR features will be disabled. Install with: pip install transformers torchaudio")
            config.global_whisper_model = None
            # Clear cache on error
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return config.global_whisper_model
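
# A minimal transcription sketch (assumed usage, not part of the original module): feeding a
# 16 kHz mono float array through the processor/model pair stored above. On ZeroGPU this
# would run inside a @spaces.GPU decorated function.
def transcribe_audio(audio_array, sampling_rate: int = 16000):
    whisper = initialize_whisper_model()
    if whisper is None:
        return None
    processor, model = whisper["processor"], whisper["model"]
    inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
    # Match the model's device and dtype (the model is loaded in float16 above)
    input_features = inputs.input_features.to(model.device, dtype=model.dtype)
    predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]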

def get_or_create_embed_model():
    """Reuse the embedding model to avoid reloading weights on each request."""
    if config.global_embed_model is None:
        logger.info("Initializing shared embedding model for RAG retrieval...")
        config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN)
    return config.global_embed_model


def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    """Get the LLM for RAG indexing (uses the medical model)."""
    medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL)
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=medical_tokenizer,
        model=medical_model_obj,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p
        }
    )
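
# A minimal RAG wiring sketch (assumed usage, not part of the original module): plugging the
# shared embedding model and the medical LLM into a llama_index query engine. The documents
# argument is a placeholder for whatever the ingestion pipeline produces.
def build_query_engine(documents, temperature: float = 0.7):
    from llama_index.core import Settings, VectorStoreIndex
    Settings.embed_model = get_or_create_embed_model()
    Settings.llm = get_llm_for_rag(temperature=temperature)
    index = VectorStoreIndex.from_documents(documents)
    return index.as_query_engine()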