import os

# Force CPU-only in this process by hiding CUDA devices (set before importing heavy libs)
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import gradio as gr
import time

# Force CPU device globally by overriding torch.cuda functions
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0


# Prevent any CUDA initialization attempts
def _dummy_lazy_init():
    pass


torch.cuda._lazy_init = _dummy_lazy_init
torch.cuda.init = _dummy_lazy_init


# Override the cuda() method to return self (stay on CPU)
def _cpu_only_cuda(self, *_args, **_kwargs):
    # Instead of moving to CUDA, just return self (stay on CPU)
    return self


torch.Tensor.cuda = _cpu_only_cuda
torch.nn.Module.cuda = _cpu_only_cuda

# =========================================
# Safe Libra Hook (CPU fallback + dtype fix)
# This hook must run before any heavyweight libra model loading occurs.
# =========================================
import libra.model.builder as builder
import libra.eval.run_libra as run_libra

# Keep a reference to the original function (if it exists)
_original_load_pretrained_model = getattr(builder, 'load_pretrained_model', None)


def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **kwargs):
    print("[INFO] Hook activated: safe_load_pretrained_model()")

    # Fill in model_name to avoid calling .lower() on None
    if model_name is None:
        model_name = model_path

    # Force the CPU device only (remove conflicting parameters)
    kwargs = dict(kwargs)
    kwargs['device'] = 'cpu'

    # Remove any parameters that might conflict
    kwargs.pop('device_map', None)
    kwargs.pop('low_cpu_mem_usage', None)
    kwargs.pop('torch_dtype', None)

    if _original_load_pretrained_model is None:
        raise RuntimeError('Original load_pretrained_model not found in builder')

    # Call the original function with the CPU device
    print(f"[INFO] Loading model with kwargs: {kwargs}")
    tokenizer, model, image_processor, context_len = _original_load_pretrained_model(
        model_path, model_base, model_name, **kwargs
    )

    # Fix tokenizer pad_token_id if it's None (common issue with Llama 3 models)
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            # tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token_id = 128001
            print(f'[INFO] Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}')
        else:
            tokenizer.pad_token_id = 0
            print('[INFO] Set pad_token_id to 0 (default)')

    # # Also ensure pad_token is set (Llama 3 specific) - CRITICAL for output
    # if tokenizer.pad_token is None:
    #     if tokenizer.eos_token is not None:
    #         tokenizer.pad_token = tokenizer.eos_token
    #         print(f'[INFO] Set pad_token to eos_token: {tokenizer.pad_token}')
    #     else:
    #         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    #         print('[INFO] Added [PAD] token to tokenizer')

    # Set padding_side to left for Llama 3 (prevents empty generation)
    if hasattr(tokenizer, 'padding_side'):
        tokenizer.padding_side = "left"
        print('[INFO] Set tokenizer.padding_side to "left" for proper generation')

    # Force all model components to CPU (keep the original dtype if possible, fall back to float32)
    print('[INFO] Ensuring all components are on CPU...')
    try:
        # Only convert to float32 if the model is in float16/bfloat16 (which is slow on CPU)
        current_dtype = next(model.parameters()).dtype
        if current_dtype in (torch.float16, torch.bfloat16):
            print(f'[INFO] Converting model from {current_dtype} to float32 for CPU compatibility...')
            model = model.to(device='cpu', dtype=torch.float32)
        else:
            print(f'[INFO] Keeping model dtype as {current_dtype} (already CPU-compatible)')
            model = model.to(device='cpu')
        print('[INFO] Model moved to CPU.')
    except Exception as e:
        print(f"[WARN] Could not move model to CPU: {e}")

    try:
        if hasattr(model, 'get_vision_tower'):
            vt = model.get_vision_tower()
            if vt is not None:
                vt_dtype = next(vt.parameters()).dtype
                if vt_dtype in (torch.float16, torch.bfloat16):
                    vt = vt.to(device='cpu', dtype=torch.float32)
                    print('[INFO] Vision tower converted to float32 for CPU.')
                else:
                    vt = vt.to(device='cpu')
                    print(f'[INFO] Vision tower moved to CPU (keeping {vt_dtype}).')
    except Exception as e:
        print(f"[WARN] Could not move vision_tower to CPU: {e}")

    try:
        if hasattr(model, 'get_model'):
            inner_model = model.get_model()
            if inner_model is not None:
                inner_dtype = next(inner_model.parameters()).dtype
                if inner_dtype in (torch.float16, torch.bfloat16):
                    inner_model = inner_model.to(device='cpu', dtype=torch.float32)
                    print('[INFO] Inner model converted to float32 for CPU.')
                else:
                    inner_model = inner_model.to(device='cpu')
                    print(f'[INFO] Inner model moved to CPU (keeping {inner_dtype}).')
    except Exception as e:
        print(f"[WARN] Could not move inner model to CPU: {e}")

    return tokenizer, model, image_processor, context_len


if _original_load_pretrained_model is not None:
    builder.load_pretrained_model = safe_load_pretrained_model


# Also replace run_libra.load_model with the safe hook
def safe_load_model(model_path, model_base=None, model_name=None):
    print('[INFO] Hook activated: safe_load_model()')
    if model_name is None:
        model_name = model_path
    return safe_load_pretrained_model(model_path, model_base, model_name)


run_libra.load_model = safe_load_model


def get_image_tensors_batch_cpu(images, image_processor, model, device='cpu'):
    """
    CPU-only version of get_image_tensors_batch.
    Keeps the same structure and behaviour as the original function.
    """
    import torch
    from libra.mm_utils import process_images

    # Normalise the input before measuring the batch size, so a single image path
    # is treated as a batch of one rather than as a sequence of characters.
    if isinstance(images, str):
        images = [images]
    elif not isinstance(images, (list, tuple)):
        raise TypeError("images must be a string or a list/tuple of strings")

    batch_size = len(images)
    all_processed = []

    for i in range(batch_size):
        images_i = [images[i]]
        # Ensure two images are present
        if len(images_i) != 2:
            images_i.append(images_i[0])
            if hasattr(model, "config") and getattr(model.config, "mm_projector_type", None) == "TAC":
                print("Contains only current image. Adding a dummy prior image for TAC.")
        processed_images = []
        for img_data in images_i:
            image_tensor = process_images([img_data], image_processor, model.config)[0]
            # ✅ Force to CPU instead of CUDA
            image_tensor = image_tensor.to(device=device, non_blocking=False)
            processed_images.append(image_tensor)

        cur_images = processed_images[0]
        prior_images = processed_images[1]
        batch_images = torch.stack([cur_images, prior_images])
        all_processed.append(batch_images)

    # ⚠️ Keep the original structure (dim=1)
    batch_images = torch.stack(all_processed, dim=1)
    return batch_images


# IMPORTANT: Patch libra functions BEFORE importing CCD
# (because ccd imports these functions during module load)
import libra.eval.run_libra as run_libra_module

# Replace the function in the module BEFORE ccd imports it
run_libra_module.get_image_tensors_batch = get_image_tensors_batch_cpu
# print('[INFO] Replaced get_image_tensors_batch with CPU-only version')

# Now import CCD and hook ccd_utils to force CPU for the expert models
import ccd.ccd_utils as ccd_utils_module

ccd_utils_module._DEVICE = torch.device('cpu')
print('[INFO] Forced ccd_utils._DEVICE to CPU')

# Now import the CCD module and patch its imported function
import ccd.run_ccd as ccd_run_ccd_module

# Replace the function that ccd.run_ccd imported
ccd_run_ccd_module.get_image_tensors_batch = get_image_tensors_batch_cpu
# print('[INFO] Patched ccd.run_ccd.get_image_tensors_batch')

# Now import the evaluation functions
from ccd import ccd_eval as _original_ccd_eval, run_eval
from libra.eval.run_libra import load_model


# Wrap ccd_eval to ensure all tensors stay on CPU
def ccd_eval_cpu_wrapper(*args, **kwargs):
    """Wrapper to ensure ccd_eval runs on CPU only."""
    import warnings
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        result = _original_ccd_eval(*args, **kwargs)
    return result


# Replace with the wrapped version
ccd_eval = ccd_eval_cpu_wrapper

# =========================================
# Global Configuration
# =========================================
MODEL_CATALOGUE = {
    "MAIRA-2": "X-iZhang/libra-maira-2",
    "Libra-v1.0-3B (⚡Recommended for CPU)": "X-iZhang/libra-v1.0-3b",
    "Libra-v1.0-7B": "X-iZhang/libra-v1.0-7b",
    "LLaVA-Med-v1.5": "X-iZhang/libra-llava-med-v1.5-mistral-7b",
    "LLaVA-Rad": "X-iZhang/libra-llava-rad",
    "Med-CXRGen-F": "X-iZhang/Med-CXRGen-F",
    "Med-CXRGen-I": "X-iZhang/Med-CXRGen-I",
}

DEFAULT_MODEL_NAME = "Libra-v1.0-3B (⚡Recommended for CPU)"

_loaded_models = {}


# =========================================
# Environment Setup
# =========================================
def setup_environment():
    print("🔹 Running in CPU-only mode (forced for Hugging Face Spaces)")
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    os.environ['TRANSFORMERS_CACHE'] = './cache'

    # Set the number of threads for CPU inference
    num_threads = min(os.cpu_count() or 4, 8)
    torch.set_num_threads(num_threads)
    print(f"🔹 Using {num_threads} CPU threads")


# =========================================
# Model Loader
# =========================================
def load_or_get_model(model_name: str):
    """Load the model based on its display name."""
    model_path = MODEL_CATALOGUE[model_name]
    print(f"🔹 Model path resolved: {model_path}")

    if model_path in _loaded_models:
        print(f"🔹 Model already loaded: {model_name}")
        return _loaded_models[model_path]

    print(f"🔹 Loading model: {model_name} ({model_path}) ...")
    print("🔹 This may take 2-5 minutes on CPU, please wait...")

    try:
        # Clear cache before loading to maximize available memory
        import gc
        gc.collect()
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()

        with torch.no_grad():
            model = load_model(model_path)

        _loaded_models[model_path] = model
        print(f"✅ Loaded successfully: {model_name}")

        # Clean up after loading
        gc.collect()
        return model
    except Exception as e:
        print(f"❌ Error loading model {model_name}: {e}")
        import traceback
        traceback.print_exc()
        raise


# =========================================
# CCD Logic
# =========================================
def generate_ccd_description(
    selected_model_name, current_img, prompt, expert_model,
    alpha, beta, gamma, use_run_eval, max_new_tokens,
    progress=gr.Progress()
):
    """Generate findings using CCD evaluation."""
    if not current_img:
        return "⚠️ Please upload or select an example image first."

    try:
        progress(0, desc="Starting inference...")
        print(f"🔹 Generating description with model: {selected_model_name}")
        print(f"🔹 Parameters: alpha={alpha}, beta={beta}, gamma={gamma}")
        print(f"🔹 Image path: {current_img}")

        progress(0.1, desc="Loading model (this may take several minutes on CPU)...")
        model = load_or_get_model(selected_model_name)

        progress(0.3, desc="Running CCD inference (this may take 5-10 minutes on CPU)...")
        print(f"🔹 Running CCD with {selected_model_name} and expert model {expert_model}...")

        # Debug: print tokenizer info
        tokenizer = model[0]
        print(f"[DEBUG] Tokenizer pad_token: {tokenizer.pad_token}")
        print(f"[DEBUG] Tokenizer pad_token_id: {tokenizer.pad_token_id}")
        print(f"[DEBUG] Tokenizer eos_token: {tokenizer.eos_token}")
        print(f"[DEBUG] Tokenizer eos_token_id: {tokenizer.eos_token_id}")
        print(f"[DEBUG] Tokenizer padding_side: {getattr(tokenizer, 'padding_side', 'NOT SET')}")
        print(f"[DEBUG] Input image path: {current_img}")
        print(f"[DEBUG] max_new_tokens: {max_new_tokens}")

        prompt = "You are a helpful AI Assistant. " + prompt
        print(f"[DEBUG] Input prompt: {prompt}")

        ccd_output = ccd_eval(
            libra_model=model,
            image=current_img,
            question=prompt,
            max_new_tokens=max_new_tokens,
            expert_model=expert_model,
            alpha=alpha,
            beta=beta,
            gamma=gamma
        )

        print(f"[DEBUG] CCD output type: {type(ccd_output)}")
        print(f"[DEBUG] CCD output length: {len(ccd_output) if ccd_output else 0}")
        print(f"[DEBUG] CCD output content: '{ccd_output}'")

        progress(0.8, desc="Processing results...")

        if use_run_eval:
            progress(0.85, desc="Running baseline comparison...")
            baseline_output = run_eval(
                libra_model=model,
                image=current_img,
                question=prompt,
                max_new_tokens=max_new_tokens,
                num_beams=1
            )
            progress(1.0, desc="Complete!")
            return (
                f"### 🩺 CCD Result ({expert_model})\n{ccd_output}\n\n"
                f"---\n### ⚖️ Baseline (run_eval)\n{baseline_output[0]}"
            )

        progress(1.0, desc="Complete!")
        return f"### 🩺 CCD Result ({expert_model})\n{ccd_output}"

    except Exception:
        import traceback, sys
        error_msg = traceback.format_exc()
        print("========== CCD ERROR LOG ==========", file=sys.stderr)
        print(error_msg, file=sys.stderr)
        print("===================================", file=sys.stderr)
        return f"❌ Exception Trace:\n```\n{error_msg}\n```"


def safe_generate_ccd_description(
    selected_model_name, current_img, prompt, expert_model,
    alpha, beta, gamma, use_run_eval, max_new_tokens
):
    """Wrapper around generate_ccd_description that logs inputs and prints a full traceback on error."""
    import traceback, sys, time
    print("\n=== Gradio callback invoked ===")
    print(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"selected_model_name={selected_model_name}")
    print(f"current_img={current_img}")
    print(f"prompt={prompt}")
    print(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, "
          f"use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}")
    try:
        return generate_ccd_description(
            selected_model_name, current_img, prompt, expert_model,
            alpha, beta, gamma, use_run_eval, max_new_tokens
        )
    except Exception:
        err = traceback.format_exc()
        print("========== GRADIO CALLBACK ERROR ==========", file=sys.stderr)
        print(err, file=sys.stderr)
        print("==========================================", file=sys.stderr)

        # Also write the error and inputs to a persistent log file for easier inspection
        try:
            with open('/workspace/CCD/callback.log', 'a', encoding='utf-8') as f:
                f.write('\n=== CALLBACK LOG ENTRY ===\n')
                f.write(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"selected_model_name={selected_model_name}\n")
                f.write(f"current_img={current_img}\n")
                f.write(f"prompt={prompt}\n")
                f.write(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, "
                        f"use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}\n")
                f.write('TRACEBACK:\n')
                f.write(err + '\n')
                f.write('=== END ENTRY ===\n')
        except Exception as fe:
            print(f"Failed to write callback.log: {fe}", file=sys.stderr)

        # Also return a user-friendly error message (with the traceback) to the UI
        return f"❌ An internal error occurred. See server logs for details.\n\nTraceback:\n```\n{err}\n```"


# =========================================
# Main Application
# =========================================
def main():
    setup_environment()

    # Example image path
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    example_path = os.path.abspath(os.path.join(cur_dir, "example.jpg"))
    example_exists = os.path.exists(example_path)

    # Model reference table
    model_table = """
| **Model Name** | **HuggingFace Link** |
|----------------|----------------------|
| **Libra-v1.0-7B** | [X-iZhang/libra-v1.0-7b](https://huggingface.co/X-iZhang/libra-v1.0-7b) |
| **Libra-v1.0-3B** | [X-iZhang/libra-v1.0-3b](https://huggingface.co/X-iZhang/libra-v1.0-3b) |
| **MAIRA-2** | [X-iZhang/libra-maira-2](https://huggingface.co/X-iZhang/libra-maira-2) |
| **LLaVA-Med-v1.5** | [X-iZhang/libra-llava-med-v1.5-mistral-7b](https://huggingface.co/X-iZhang/libra-llava-med-v1.5-mistral-7b) |
| **LLaVA-Rad** | [X-iZhang/libra-llava-rad](https://huggingface.co/X-iZhang/libra-llava-rad) |
| **Med-CXRGen-F** | [X-iZhang/Med-CXRGen-F](https://huggingface.co/X-iZhang/Med-CXRGen-F) |
| **Med-CXRGen-I** | [X-iZhang/Med-CXRGen-I](https://huggingface.co/X-iZhang/Med-CXRGen-I) |
"""

    with gr.Blocks(title="📷 Clinical Contrastive Decoding", theme="soft") as demo:
        gr.Markdown("""
# 📷 CCD: Mitigating Hallucinations in Radiology MLLMs via Clinical Contrastive Decoding

### [Project Page](https://x-izhang.github.io/CCD/) | [Paper](https://arxiv.org/abs/2509.23379) | [Code](https://github.com/X-iZhang/CCD) | [Models](https://huggingface.co/collections/X-iZhang/libra-6772bfccc6079298a0fa5f8d)

**🚨 Performance Warning**
This demo runs in **CPU-only** mode. A single inference may take **25-30 minutes** depending on the model and parameters.

**Recommendations for faster inference:**
- Use smaller models (Libra-v1.0-3B is faster than the 7B models). **The model has already been loaded** ⏬
- Please do not attempt to load other models, as this may cause a **runtime error**: "Workload evicted, storage limit exceeded (50G)"
- Reduce `Max New Tokens` to 64-128 (default: 128)
- Disable baseline comparison
- For GPU acceleration, please [run the demo locally](https://github.com/X-iZhang/CCD#gradio-web-interface)

**Note:** If you see "Connection Lost", please wait - the inference is still running.
The results will appear when complete.
""")

        with gr.Tab("✨ CCD Demo"):
            with gr.Row():
                # -------- Left Column: Image --------
                with gr.Column(scale=1):
                    gr.Markdown("### Radiology Image (e.g. Chest X-ray)")
                    current_img = gr.Image(label="Radiology Image", type="filepath", interactive=True)

                    if example_exists:
                        gr.Examples(
                            examples=[[example_path]],
                            inputs=[current_img],
                            label="Example Image"
                        )
                    else:
                        gr.Markdown(f"⚠️ Example image not found at `{example_path}`")

                # -------- Right Column: Controls --------
                with gr.Column(scale=1):
                    gr.Markdown("### Model Selection & Prompt")
                    selected_model_name = gr.Dropdown(
                        label="Base Radiology MLLM",
                        choices=list(MODEL_CATALOGUE.keys()),
                        value=DEFAULT_MODEL_NAME
                    )
                    prompt = gr.Textbox(
                        label="Question / Prompt",
                        value="What are the findings in this chest X-ray? Give a detailed description.",
                        lines=1
                    )

                    gr.Markdown("### CCD Parameters")
                    expert_model = gr.Radio(
                        label="Expert Model",
                        choices=["MedSigLip", "DenseNet"],
                        value="DenseNet"
                    )

                    # Notice for MedSigLip access requirements (hidden by default)
                    medsiglip_message = (
                        "**Note: The MedSigLip model requires authorization to access.**\n\n"
                        "To use MedSigLip, please deploy the Gradio Web Interface locally and complete the authentication steps.\n"
                        "See deployment instructions and how to run locally here: "
                        "[Gradio Web Interface](https://github.com/X-iZhang/CCD#gradio-web-interface)"
                    )
                    medsiglip_notice = gr.Markdown(value="", visible=False)

                    def _toggle_medsiglip_notice(choice):
                        if choice == "MedSigLip":
                            return gr.update(visible=True, value=medsiglip_message)
                        else:
                            return gr.update(visible=False, value="")

                    # Connect the radio change to the notice visibility
                    expert_model.change(fn=_toggle_medsiglip_notice, inputs=[expert_model], outputs=[medsiglip_notice])

                    with gr.Row():
                        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Alpha")
                        beta = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Beta")
                        gamma = gr.Slider(0, 20, value=10, step=1, label="Gamma")

                    with gr.Accordion("Advanced Options", open=False):
                        max_new_tokens = gr.Slider(10, 256, value=64, step=1, label="Max New Tokens (lower = faster)")
                        use_run_eval = gr.Checkbox(label="Compare with baseline (run_eval) [doubles inference time]", value=False)

                    generate_btn = gr.Button("🚀 Generate", variant="primary")

            # -------- Output --------
            # output = gr.Markdown(label="Output", value="### 📷 Results will appear here.👇")
            output = gr.Markdown(
                value='