import os
# Force CPU-only in this process by hiding CUDA devices (set before importing heavy libs)
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import gradio as gr
import time

# Force CPU device globally by overriding torch.cuda functions
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0

# Prevent any CUDA initialization attempts
def _dummy_lazy_init():
    pass

torch.cuda._lazy_init = _dummy_lazy_init
torch.cuda.init = _dummy_lazy_init

# Override cuda() method to return self (stay on CPU)
def _cpu_only_cuda(self, *_args, **_kwargs):
    # Instead of moving to CUDA, just return self (stay on CPU)
    return self

torch.Tensor.cuda = _cpu_only_cuda
torch.nn.Module.cuda = _cpu_only_cuda
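

# Optional sanity check for the CPU-only patches above (a minimal sketch, not
# called automatically): with torch.Tensor.cuda patched to return self, a tensor
# asked to move to CUDA should stay on CPU, and torch.cuda.is_available() should
# report False.
def _assert_cpu_only_patches():
    probe = torch.zeros(1).cuda()  # patched .cuda() returns the same CPU tensor
    assert probe.device.type == 'cpu', 'torch.Tensor.cuda patch did not take effect'
    assert not torch.cuda.is_available(), 'torch.cuda.is_available() was not overridden'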

# =========================================
# Safe Libra Hook (CPU fallback + dtype fix)
# This hook must run before any heavyweight libra model-loading occurs.
# =========================================
import libra.model.builder as builder
import libra.eval.run_libra as run_libra

# Save the original function (if it exists)
_original_load_pretrained_model = getattr(builder, 'load_pretrained_model', None)

def safe_load_pretrained_model(model_path, model_base=None, model_name=None, **kwargs):
    print("[INFO] Hook activated: safe_load_pretrained_model()")

    # Complete model_name to avoid .lower() on None
    if model_name is None:
        model_name = model_path

    # Force CPU device only (remove conflicting parameters)
    kwargs = dict(kwargs)
    kwargs['device'] = 'cpu'

    # Remove any parameters that might conflict
    kwargs.pop('device_map', None)
    kwargs.pop('low_cpu_mem_usage', None)
    kwargs.pop('torch_dtype', None)

    if _original_load_pretrained_model is None:
        raise RuntimeError('Original load_pretrained_model not found in builder')

    # Call original function with CPU device
    print(f"[INFO] Loading model with kwargs: {kwargs}")
    tokenizer, model, image_processor, context_len = _original_load_pretrained_model(
        model_path, model_base, model_name, **kwargs
    )

    # Fix tokenizer pad_token_id if it's None (common issue with Llama 3 models)
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            # Hard-coded to 128001, the Llama 3 '<|end_of_text|>' id, instead of
            # tokenizer.eos_token_id (which some checkpoints set to '<|eot_id|>').
            tokenizer.pad_token_id = 128001
            print(f'[INFO] Set pad_token_id to {tokenizer.pad_token_id} (Llama 3 end-of-text)')
        else:
            tokenizer.pad_token_id = 0
            print('[INFO] Set pad_token_id to 0 (default)')

    # # Also ensure pad_token is set (Llama 3 specific) - CRITICAL for output
    # if tokenizer.pad_token is None:
    #     if tokenizer.eos_token is not None:
    #         tokenizer.pad_token_id = tokenizer.eos_token
    #         print(f'[INFO] Set pad_token to eos_token: {tokenizer.pad_token_id}')
    #     else:
    #         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    #         print('[INFO] Added [PAD] token to tokenizer')

    # Set padding_side to left for Llama 3 (prevents empty generation)
    if hasattr(tokenizer, 'padding_side'):
        tokenizer.padding_side = "left"
        print('[INFO] Set tokenizer.padding_side to "left" for proper generation')

    # Force all model components to CPU (keep original dtype if possible, fallback to float32)
    print('[INFO] Ensuring all components are on CPU...')
    try:
        # Only convert to float32 if model is in float16 (which is slow on CPU)
        current_dtype = next(model.parameters()).dtype
        if current_dtype == torch.float16 or current_dtype == torch.bfloat16:
            print(f'[INFO] Converting model from {current_dtype} to float32 for CPU compatibility...')
            model = model.to(device='cpu', dtype=torch.float32)
        else:
            print(f'[INFO] Keeping model dtype as {current_dtype} (already CPU-compatible)')
            model = model.to(device='cpu')
        print('[INFO] Model moved to CPU.')
    except Exception as e:
        print(f"[WARN] Could not move model to CPU: {e}")

    try:
        if hasattr(model, 'get_vision_tower'):
            vt = model.get_vision_tower()
            if vt is not None:
                vt_dtype = next(vt.parameters()).dtype
                if vt_dtype == torch.float16 or vt_dtype == torch.bfloat16:
                    vt = vt.to(device='cpu', dtype=torch.float32)
                    print(f'[INFO] Vision tower converted to float32 for CPU.')
                else:
                    vt = vt.to(device='cpu')
                    print(f'[INFO] Vision tower moved to CPU (keeping {vt_dtype}).')
    except Exception as e:
        print(f"[WARN] Could not move vision_tower to CPU: {e}")

    try:
        if hasattr(model, 'get_model'):
            inner_model = model.get_model()
            if inner_model is not None:
                inner_dtype = next(inner_model.parameters()).dtype
                if inner_dtype == torch.float16 or inner_dtype == torch.bfloat16:
                    inner_model = inner_model.to(device='cpu', dtype=torch.float32)
                    print(f'[INFO] Inner model converted to float32 for CPU.')
                else:
                    inner_model = inner_model.to(device='cpu')
                    print(f'[INFO] Inner model moved to CPU (keeping {inner_dtype}).')
    except Exception as e:
        print(f"[WARN] Could not move inner model to CPU: {e}")

    return tokenizer, model, image_processor, context_len


if _original_load_pretrained_model is not None:
    builder.load_pretrained_model = safe_load_pretrained_model

# Also replace run_libra.load_model with the safe CPU loader
def safe_load_model(model_path, model_base=None, model_name=None):
    print('[INFO] Hook activated: safe_load_model()')
    if model_name is None:
        model_name = model_path
    return safe_load_pretrained_model(model_path, model_base, model_name)

run_libra.load_model = safe_load_model

def get_image_tensors_batch_cpu(images, image_processor, model, device='cpu'):
    """
    CPU-only version of get_image_tensors_batch.
    Keeps the same structure and behaviour as the original function.
    """
    import torch
    from libra.mm_utils import process_images
    
    # Normalise the input before computing the batch size; otherwise a single
    # path string would be iterated character by character.
    if isinstance(images, str):
        images = [images]
    elif not isinstance(images, (list, tuple)):
        raise TypeError("images must be a string or a list/tuple of strings")

    batch_size = len(images)
    all_processed = []

    for i in range(batch_size):
        images_i = [images[i]]

        # The model expects a (current, prior) pair; when only the current image
        # is given, duplicate it as a dummy prior.
        if len(images_i) != 2:
            images_i.append(images_i[0])
            if hasattr(model, "config") and getattr(model.config, "mm_projector_type", None) == "TAC":
                print("Contains only current image. Adding a dummy prior image for TAC.")

        processed_images = []
        for img_data in images_i:
            image_tensor = process_images([img_data], image_processor, model.config)[0]
            
            # ✅ Force to CPU instead of CUDA
            image_tensor = image_tensor.to(device=device, non_blocking=False)
            processed_images.append(image_tensor)

        cur_images = processed_images[0]
        prior_images = processed_images[1]

        batch_images = torch.stack([cur_images, prior_images])
        all_processed.append(batch_images)

    # ⚠️ Keep the original structure (dim=1)
    batch_images = torch.stack(all_processed, dim=1)

    return batch_images
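
# Illustrative usage of the CPU-only helper above (a sketch; 'chest_xray.jpg' is a
# hypothetical path, and the shape assumes the image processor returns one
# (C, H, W) tensor per image):
#
#     tensors = get_image_tensors_batch_cpu(['chest_xray.jpg'], image_processor, model)
#     # tensors.shape == (2, 1, C, H, W): dim 0 stacks [current, prior],
#     # dim 1 is the batch dimension, matching the original function's layout.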

# IMPORTANT: Patch libra functions BEFORE importing CCD
# (because ccd imports these functions during module load)
import libra.eval.run_libra as run_libra_module

# Replace the function in the module BEFORE ccd imports it
run_libra_module.get_image_tensors_batch = get_image_tensors_batch_cpu
# print('[INFO] Replaced get_image_tensors_batch with CPU-only version')

# Now import CCD and hook ccd_utils to force CPU for expert models
import ccd.ccd_utils as ccd_utils_module
ccd_utils_module._DEVICE = torch.device('cpu')
print('[INFO] Forced ccd_utils._DEVICE to CPU')

# Now import CCD module and patch its imported function
import ccd.run_ccd as ccd_run_ccd_module
# Replace the function that ccd.run_ccd imported
ccd_run_ccd_module.get_image_tensors_batch = get_image_tensors_batch_cpu
# print('[INFO] Patched ccd.run_ccd.get_image_tensors_batch')

# Now import the evaluation functions
from ccd import ccd_eval as _original_ccd_eval, run_eval
from libra.eval.run_libra import load_model

# Wrap ccd_eval to ensure all tensors stay on CPU
def ccd_eval_cpu_wrapper(*args, **kwargs):
    """Wrapper to ensure ccd_eval runs on CPU only"""
    import warnings
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        result = _original_ccd_eval(*args, **kwargs)
    return result

# Replace with wrapped version
ccd_eval = ccd_eval_cpu_wrapper

# =========================================
# Global Configuration
# =========================================
MODEL_CATALOGUE = {
    "MAIRA-2": "X-iZhang/libra-maira-2",
    "Libra-v1.0-3B (⚡Recommended for CPU)": "X-iZhang/libra-v1.0-3b",
    "Libra-v1.0-7B": "X-iZhang/libra-v1.0-7b",
    "LLaVA-Med-v1.5": "X-iZhang/libra-llava-med-v1.5-mistral-7b",
    "LLaVA-Rad": "X-iZhang/libra-llava-rad",
    "Med-CXRGen-F": "X-iZhang/Med-CXRGen-F",
    "Med-CXRGen-I": "X-iZhang/Med-CXRGen-I"
}
DEFAULT_MODEL_NAME = "Libra-v1.0-3B (⚡Recommended for CPU)"
_loaded_models = {}


# =========================================
# Environment Setup
# =========================================
def setup_environment():
    print("🔹 Running in CPU-only mode (forced for Hugging Face Spaces)")
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    os.environ['TRANSFORMERS_CACHE'] = './cache'

    # Set number of threads for CPU inference
    num_threads = min(os.cpu_count() or 4, 8)
    torch.set_num_threads(num_threads)
    print(f"🔹 Using {num_threads} CPU threads")


# =========================================
# Model Loader
# =========================================
def load_or_get_model(model_name: str):
    """Load the model based on its display name."""
    model_path = MODEL_CATALOGUE[model_name]
    print(f"🔹 Model path resolved: {model_path}")
    if model_path in _loaded_models:
        print(f"🔹 Model already loaded: {model_name}")
        return _loaded_models[model_path]

    print(f"🔹 Loading model: {model_name} ({model_path}) ...")
    print(f"🔹 This may take 2-5 minutes on CPU, please wait...")
    try:
        # Clear cache before loading to maximize available memory
        import gc
        gc.collect()
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()

        with torch.no_grad():
            model = load_model(model_path)

        _loaded_models[model_path] = model
        print(f"✅ Loaded successfully: {model_name}")

        # Clean up after loading
        gc.collect()
        return model
    except Exception as e:
        print(f"❌ Error loading model {model_name}: {e}")
        import traceback
        traceback.print_exc()
        raise
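

# Optional warm-up sketch (an assumption, not part of the original flow): calling
# this from main() after setup_environment() would hide the multi-minute first
# model load from the first user request. It is best-effort; the UI path above
# still loads on demand if pre-warming fails.
def prewarm_default_model():
    try:
        load_or_get_model(DEFAULT_MODEL_NAME)
    except Exception as exc:
        print(f"[WARN] Pre-warm of {DEFAULT_MODEL_NAME} failed: {exc}")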


# =========================================
# CCD Logic
# =========================================
def generate_ccd_description(
    selected_model_name,
    current_img,
    prompt,
    expert_model,
    alpha,
    beta,
    gamma,
    use_run_eval,
    max_new_tokens,
    progress=gr.Progress()
):
    """Generate findings using CCD evaluation."""
    if not current_img:
        return "⚠️ Please upload or select an example image first."

    try:
        progress(0, desc="Starting inference...")
        print(f"🔹 Generating description with model: {selected_model_name}")
        print(f"🔹 Parameters: alpha={alpha}, beta={beta}, gamma={gamma}")
        print(f"🔹 Image path: {current_img}")

        progress(0.1, desc="Loading model (this may take several minutes on CPU)...")
        model = load_or_get_model(selected_model_name)

        progress(0.3, desc="Running CCD inference (this may take 5-10 minutes on CPU)...")
        print(f"🔹 Running CCD with {selected_model_name} and expert model {expert_model}...")


        # Debug: Print tokenizer info
        tokenizer = model[0]
        print(f"[DEBUG] Tokenizer pad_token: {tokenizer.pad_token}")
        print(f"[DEBUG] Tokenizer pad_token_id: {tokenizer.pad_token_id}")
        print(f"[DEBUG] Tokenizer eos_token: {tokenizer.eos_token}")
        print(f"[DEBUG] Tokenizer eos_token_id: {tokenizer.eos_token_id}")
        print(f"[DEBUG] Tokenizer padding_side: {getattr(tokenizer, 'padding_side', 'NOT SET')}")
        print(f"[DEBUG] Input image path: {current_img}")
        print(f"[DEBUG] max_new_tokens: {max_new_tokens}")    

        prompt = "You are a helpful AI Assistant. " + prompt  
        print(f"[DEBUG] Input prompt: {prompt}")
        
        ccd_output = ccd_eval(
            libra_model=model,
            image=current_img,
            question=prompt,
            max_new_tokens=max_new_tokens,
            expert_model=expert_model,
            alpha=alpha,
            beta=beta,
            gamma=gamma
        )
        
        print(f"[DEBUG] CCD output type: {type(ccd_output)}")
        print(f"[DEBUG] CCD output length: {len(ccd_output) if ccd_output else 0}")
        print(f"[DEBUG] CCD output content: '{ccd_output}'")
        
        progress(0.8, desc="Processing results...")

        if use_run_eval:
            progress(0.85, desc="Running baseline comparison...")
            baseline_output = run_eval(
                libra_model=model,
                image=current_img,
                question=prompt,
                max_new_tokens=max_new_tokens,
                num_beams=1
            )
            progress(1.0, desc="Complete!")
            return (
                f"### 🩺 CCD Result ({expert_model})\n{ccd_output}\n\n"
                f"---\n### ⚖️ Baseline (run_eval)\n{baseline_output[0]}"
            )

        progress(1.0, desc="Complete!")
        return f"### 🩺 CCD Result ({expert_model})\n{ccd_output}"

    except Exception:
        import traceback, sys
        error_msg = traceback.format_exc()
        print("========== CCD ERROR LOG ==========", file=sys.stderr)
        print(error_msg, file=sys.stderr)
        print("===================================", file=sys.stderr)
        return f"❌ Exception Trace:\n```\n{error_msg}\n```"


def safe_generate_ccd_description(
    selected_model_name,
    current_img,
    prompt,
    expert_model,
    alpha,
    beta,
    gamma,
    use_run_eval,
    max_new_tokens
):
    """Wrapper around generate_ccd_description that logs inputs and prints full traceback on error."""
    import traceback, sys, time
    print("\n=== Gradio callback invoked ===")
    print(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"selected_model_name={selected_model_name}")
    print(f"current_img={current_img}")
    print(f"prompt={prompt}")
    print(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}")

    try:
        return generate_ccd_description(
            selected_model_name,
            current_img,
            prompt,
            expert_model,
            alpha,
            beta,
            gamma,
            use_run_eval,
            max_new_tokens
        )
    except Exception as e:
        err = traceback.format_exc()
        print("========== GRADIO CALLBACK ERROR ==========", file=sys.stderr)
        print(err, file=sys.stderr)
        print("==========================================", file=sys.stderr)
        # Also write the error and inputs to a persistent log file for easier inspection
        try:
            os.makedirs('logs', exist_ok=True)
            with open('logs/callback.log', 'a', encoding='utf-8') as f:
                f.write('\n=== CALLBACK LOG ENTRY ===\n')
                f.write(f"timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"selected_model_name={selected_model_name}\n")
                f.write(f"current_img={current_img}\n")
                f.write(f"prompt={prompt}\n")
                f.write(f"expert_model={expert_model}, alpha={alpha}, beta={beta}, gamma={gamma}, use_run_eval={use_run_eval}, max_new_tokens={max_new_tokens}\n")
                f.write('TRACEBACK:\n')
                f.write(err + '\n')
                f.write('=== END ENTRY ===\n')
        except Exception as fe:
            print(f"Failed to write callback.log: {fe}", file=sys.stderr)
        # Also return a user-friendly error message to the UI with traceback
        return f"❌ An internal error occurred. See server logs for details.\n\nTraceback:\n```\n{err}\n```"


# =========================================
# Main Application
# =========================================
def main():
    setup_environment()

    # Example Image Path
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    example_path = os.path.abspath(os.path.join(cur_dir, "example.jpg"))
    example_exists = os.path.exists(example_path)

    # Model reference table
    model_table = """
| **Model Name** | **HuggingFace Link** |
|----------------|----------------------|
| **Libra-v1.0-7B** | [X-iZhang/libra-v1.0-7b](https://huggingface.co/X-iZhang/libra-v1.0-7b) |
| **Libra-v1.0-3B** | [X-iZhang/libra-v1.0-3b](https://huggingface.co/X-iZhang/libra-v1.0-3b) |
| **MAIRA-2** | [X-iZhang/libra-maira-2](https://huggingface.co/X-iZhang/libra-maira-2) |
| **LLaVA-Med-v1.5** | [X-iZhang/libra-llava-med-v1.5-mistral-7b](https://huggingface.co/X-iZhang/libra-llava-med-v1.5-mistral-7b) |
| **LLaVA-Rad** | [X-iZhang/libra-llava-rad](https://huggingface.co/X-iZhang/libra-llava-rad) |
| **Med-CXRGen-F** | [X-iZhang/Med-CXRGen-F](https://huggingface.co/X-iZhang/Med-CXRGen-F) |
| **Med-CXRGen-I** | [X-iZhang/Med-CXRGen-I](https://huggingface.co/X-iZhang/Med-CXRGen-I) |
    """

    with gr.Blocks(title="📷 Clinical Contrastive Decoding", theme="soft") as demo:
        gr.Markdown("""
        # 📷 CCD: Mitigating Hallucinations in Radiology MLLMs via Clinical Contrastive Decoding
        ### [Project Page](https://x-izhang.github.io/CCD/) | [Paper](https://arxiv.org/abs/2509.23379) | [Code](https://github.com/X-iZhang/CCD) | [Models](https://huggingface.co/collections/X-iZhang/libra-6772bfccc6079298a0fa5f8d)
        
        **🚨 Performance Warning**

        This demo is running on **CPU-only** mode. A single inference may take **25-30 minutes** depending on the model and parameters.

        **Recommendations for faster inference:**
        - Use smaller models: Libra-v1.0-3B is faster than the 7B models. **The default model has already been loaded** ⏬
        - Please do not attempt to load other models, as this may cause a **runtime error**: "Workload evicted, storage limit exceeded (50G)"
        - Reduce `Max New Tokens` to 64-128 (default: 64)
        - Disable baseline comparison
        - For GPU acceleration, please [run the demo locally](https://github.com/X-iZhang/CCD#gradio-web-interface)
        
        **Note:** If you see "Connection Lost", please wait - the inference is still running. The results will appear when complete.
        """)

        with gr.Tab("✨ CCD Demo"):
            with gr.Row():
                # -------- Left Column: Image --------
                with gr.Column(scale=1):
                    gr.Markdown("### Radiology Image (eg. Chest X-ray)")
                    current_img = gr.Image(label="Radiology Image", type="filepath", interactive=True)
                    if example_exists:
                        gr.Examples(
                            examples=[[example_path]],
                            inputs=[current_img],
                            label="Example Image"
                        )
                    else:
                        gr.Markdown(f"⚠️ Example image not found at `{example_path}`")

                # -------- Right Column: Controls --------
                with gr.Column(scale=1):
                    gr.Markdown("### Model Selection & Prompt")
                    selected_model_name = gr.Dropdown(
                        label="Base Radiology MLLM",
                        choices=list(MODEL_CATALOGUE.keys()),
                        value=DEFAULT_MODEL_NAME
                    )
                    prompt = gr.Textbox(
                        label="Question / Prompt",
                        value="What are the findings in this chest X-ray? Give a detailed description.",
                        lines=1
                    )

                    gr.Markdown("### CCD Parameters")
                    expert_model = gr.Radio(
                        label="Expert Model",
                        choices=["MedSigLip", "DenseNet"],
                        value="DenseNet"
                    )

                    # Notice for MedSigLip access requirements (hidden by default)
                    medsiglip_message = (
                        "**Note: The MedSigLip model requires authorization to access.**\n\n"
                        "To use MedSigLip, please deploy the Gradio Web Interface locally and complete the authentication steps.\n"
                        "See deployment instructions and how to run locally here: "
                        "[Gradio Web Interface](https://github.com/X-iZhang/CCD#gradio-web-interface)"
                    )
                    medsiglip_notice = gr.Markdown(value="", visible=False)

                    def _toggle_medsiglip_notice(choice):
                        if choice == "MedSigLip":
                            return gr.update(visible=True, value=medsiglip_message)
                        else:
                            return gr.update(visible=False, value="")

                    # Connect radio change to the notice visibility
                    expert_model.change(fn=_toggle_medsiglip_notice, inputs=[expert_model], outputs=[medsiglip_notice])

                    with gr.Row():
                        alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Alpha")
                        beta = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Beta")
                        gamma = gr.Slider(0, 20, value=10, step=1, label="Gamma")

                    with gr.Accordion("Advanced Options", open=False):
                        max_new_tokens = gr.Slider(10, 256, value=64, step=1, label="Max New Tokens (lower = faster)")
                        use_run_eval = gr.Checkbox(label="Compare with baseline (run_eval) [doubles inference time]", value=False)

                    generate_btn = gr.Button("🚀 Generate", variant="primary")

            # -------- Output --------
            # output = gr.Markdown(label="Output", value="### 📷 Results will appear here.👇")
            output = gr.Markdown(
                value='<h3 style="color:#007BFF;">📷 Results will appear here.👇</h3>',
                label="Output"
            )
            # Switch callback to the safe wrapper
            generate_btn.click(
                fn=safe_generate_ccd_description,
                inputs=[
                    selected_model_name, current_img, prompt,
                    expert_model, alpha, beta, gamma,
                    use_run_eval, max_new_tokens
                ],
                outputs=output
            )

        # -------- Model Table --------
        # gr.Markdown("### 🧠 Supported Models")
        # gr.Markdown(model_table)

        gr.Markdown("""
        ### Terms of Use
        The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
        
        By accessing or using this demo, you acknowledge and agree to the following:
        - **Research & Non-Commercial Purposes**: This demo is provided solely for research and demonstration. It must not be used for commercial activities or profit-driven endeavors.
        - **Not Medical Advice**: All generated content is experimental and must not replace professional medical judgment.
        - **Content Moderation**: While we apply basic safety checks, the system may still produce inaccurate or offensive outputs.
        - **Responsible Use**: Do not use this demo for any illegal, harmful, hateful, violent, or sexual purposes.

        By continuing to use this service, you confirm your acceptance of these terms. If you do not agree, please discontinue use immediately.
        """)


    # Log that Gradio is starting (helpful when stdout/stderr are captured)
    # write startup log to local file in repository (avoid permission issues on Spaces)
    try:
        os.makedirs('logs', exist_ok=True)
        with open('logs/callback.log', 'a', encoding='utf-8') as f:
            f.write(f"\n=== GRADIO START ===\nstarted_at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
    except Exception:
        pass


    # Launch with extended timeout for CPU inference
    demo.queue(max_size=10)  # Enable queue for better handling of long-running tasks
    demo.launch(
        max_threads=5,  # Limit concurrent requests
        show_error=True  # Show detailed errors
    )


if __name__ == "__main__":
    main()