Spaces:
Sleeping
Sleeping
| """Gradio UI for PanCancerSeg single-case CT tumour segmentation.""" | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| from predict import ( | |
| CANCER_CONFIGS, | |
| install_custom_trainer, | |
| resolve_case_id, | |
| resolve_model_folder, | |
| run_nnunet_prediction_single, | |
| summarize_segmentation, | |
| ) | |
| from visualize import generate_outputs | |
| # ── Constants ────────────────────────────────────────────────────────────────── | |
| CANCER_TYPE_CHOICES = { | |
| "Kidney Cancer": "kidney_cancer", | |
| "Liver Cancer": "liver_cancer", | |
| "Pancreatic Cancer": "pancreatic_cancer", | |
| "Lung Cancer": "lung_cancer", | |
| } | |
| DEFAULT_MODEL_DIR = str(Path(__file__).parent / "PanCancerSeg-Specialized-weights") | |
| DEFAULT_DEVICE = "cuda" | |
| # Hugging Face Hub repo that hosts the trained nnUNet weights. On Spaces (where the | |
| # local weights folder is absent) we download them on first use. | |
| MODEL_REPO_ID = "KS987/PanCancerSeg-Specialized-weights" | |
| # Resolved once per process; subsequent inferences reuse it (no re-download). | |
| _WEIGHTS_DIR: Path | None = None | |
def resolve_weights_dir() -> Path:
    """Locate the directory that holds the ``DatasetXXX_*`` model folders.

    A local checkout at ``DEFAULT_MODEL_DIR`` wins if it contains at least one
    ``Dataset*`` entry. Otherwise the ``Dataset*/**`` files of ``MODEL_REPO_ID``
    are fetched with ``huggingface_hub.snapshot_download``.

    The result is memoised in the module-level ``_WEIGHTS_DIR``, so later calls
    in the same process return immediately without touching the Hub. A failed
    download leaves the cache empty, so the next call retries.
    """
    global _WEIGHTS_DIR
    if _WEIGHTS_DIR is None:
        candidate = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
        has_local_weights = candidate.exists() and any(candidate.glob("Dataset*"))
        if has_local_weights:
            _WEIGHTS_DIR = candidate
        else:
            # Imported lazily: only needed when the weights are not local.
            from huggingface_hub import snapshot_download

            snapshot_path = snapshot_download(
                repo_id=MODEL_REPO_ID,
                repo_type="model",
                allow_patterns=["Dataset*/**"],
            )
            _WEIGHTS_DIR = Path(snapshot_path)
    return _WEIGHTS_DIR
# ── ZeroGPU support ──────────────────────────────────────────────────────────
# ZeroGPU Spaces ship the `spaces` package and require GPU work to run inside a
# `@spaces.GPU`-decorated function. Elsewhere (local machine, dedicated GPU
# Space) the package is missing, and gpu_task() degrades to a pass-through
# decorator so the same code runs everywhere.
try:
    import spaces  # type: ignore
except ImportError:
    spaces = None

_HAS_ZEROGPU = spaces is not None


def gpu_task(duration: int = 180):
    """Return a decorator that marks a function as ZeroGPU GPU work.

    Args:
        duration: Forwarded unchanged as ``spaces.GPU(duration=...)``.

    Returns:
        ``spaces.GPU(duration=duration)`` when ``spaces`` was importable;
        otherwise a decorator that hands back the function untouched.
    """
    if not _HAS_ZEROGPU:
        return lambda fn: fn
    return spaces.GPU(duration=duration)
@gpu_task()
def run_gpu_segmentation(model_folder_str: str, input_file_str: str, output_file_str: str) -> None:
    """Run nnUNet inference on one case on the GPU.

    Decorated with ``gpu_task()``. On ZeroGPU that applies ``spaces.GPU``, so
    the call runs in the ZeroGPU worker process; elsewhere the decorator is a
    no-op and the function runs in-process.

    It uses the single-case, no-multiprocessing prediction path because the
    ZeroGPU worker is a daemon process that may not spawn children.

    Args:
        model_folder_str: Path of the trained nnUNet model folder.
        input_file_str: Path of the ``<case>_0000.nii.gz`` input image.
        output_file_str: Path where the predicted segmentation is written.
    """
    # Register the custom trainer inside the (possibly separate) GPU worker
    # process so nnUNet can discover it when initialising from the model folder.
    install_custom_trainer()
    run_nnunet_prediction_single(
        model_folder=model_folder_str,
        input_file=input_file_str,
        output_file=output_file_str,
        device="cuda",
    )
# Bundled example CT volumes live in sample_input/<folder>/ next to this file.
_SAMPLE_DIR = Path(__file__).parent.joinpath("sample_input")

# UI cancer-type label -> sub-folder of _SAMPLE_DIR holding its examples.
_CANCER_TYPE_TO_FOLDER = dict(
    zip(
        ("Kidney Cancer", "Liver Cancer", "Pancreatic Cancer", "Lung Cancer"),
        ("kidney", "liver", "pancreas", "lung"),
    )
)
def load_example(cancer_type_label: str, index: int) -> str:
    """Return the path of the ``index``-th (1-based) bundled example CT volume.

    Examples are the ``*_0000.nii.gz`` files in the cancer type's sample folder,
    taken in sorted filename order.

    Args:
        cancer_type_label: UI label; must be a key of ``_CANCER_TYPE_TO_FOLDER``.
        index: 1-based position of the example in sorted order.

    Returns:
        The example's path as a string (fed to the ``gr.File`` input).

    Raises:
        gr.Error: If ``index`` is not in ``1..number of examples``.
    """
    folder = _SAMPLE_DIR / _CANCER_TYPE_TO_FOLDER[cancer_type_label]
    files = sorted(folder.glob("*_0000.nii.gz"))
    # Check both bounds: with only an upper-bound check, index <= 0 would
    # silently wrap around via negative indexing (index 0 -> last file).
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])
def count_examples(cancer_type_label: str) -> int:
    """Return the number of bundled ``*_0000.nii.gz`` example volumes for a cancer type.

    Returns 0 when the cancer type's sample folder does not exist.
    """
    folder = _SAMPLE_DIR / _CANCER_TYPE_TO_FOLDER[cancer_type_label]
    if not folder.exists():
        return 0
    # Only the count is needed, so skip sorting the matches.
    return sum(1 for _ in folder.glob("*_0000.nii.gz"))
def available_cancer_labels(weights_dir) -> list:
    """List the UI cancer labels whose ``DatasetXXX`` folder exists in ``weights_dir``.

    A single-cancer Space bundles exactly one dataset folder, so one label is
    returned and the UI locks to it. A full checkout with all four datasets
    returns every label and the UI shows the selector. If no dataset folder is
    found at all, every label is returned as a fallback.
    """
    root = Path(weights_dir)
    present = []
    for ui_label, config_key in CANCER_TYPE_CHOICES.items():
        dataset_folder = root / CANCER_CONFIGS[config_key]["dataset_name"]
        if dataset_folder.exists():
            present.append(ui_label)
    if present:
        return present
    return list(CANCER_TYPE_CHOICES)
# ── Inference ──────────────────────────────────────────────────────────────────


def _path_or_none(path):
    """Return ``str(path)``, or ``None`` when ``path`` is ``None``.

    ``str(None)`` would hand Gradio the bogus file path ``"None"``; returning
    ``None`` leaves the output component empty instead.
    """
    return None if path is None else str(path)


def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment a single uploaded CT volume and build every UI output.

    Args:
        input_file: Path-like location of the uploaded ``.nii.gz`` file (from
            ``gr.File``), or ``None`` if nothing was uploaded.
        cancer_type_label: A key of ``CANCER_TYPE_CHOICES``.
        fps: Frame rate for the overlay video; converted with ``int()``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Tuple ``(stats, seg_path, centroid, max_area, extent25, extent75,
        video)``: a summary string, the mask path, four slice-image paths and
        the overlay-video path. Any image or video that was not produced is
        returned as ``None``.

    Raises:
        gr.Error: On a missing or non-``.nii.gz`` upload, on failure to resolve
            the weights or the model folder, and on any failure during
            prediction or visualisation (the output directory is then removed).
    """
    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")
    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}") from e

    cancer_key = CANCER_TYPE_CHOICES[cancer_type_label]
    config = CANCER_CONFIGS[cancer_key]
    case_id = resolve_case_id(input_path)

    progress(0.10, desc="Loading model weights...")
    try:
        model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])
    except Exception as e:
        # Surface the reason to the user, as for weight resolution above.
        raise gr.Error(str(e) or type(e).__name__) from e

    # Not a TemporaryDirectory: the returned file paths point into it, so it
    # must outlive this call. It is only removed on failure (see below).
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))
    try:
        with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
            tmp_path = Path(tmp)
            tmp_input_dir = tmp_path / "input"
            tmp_output_dir = tmp_path / "prediction"
            tmp_input_dir.mkdir()
            tmp_output_dir.mkdir()

            # nnUNet expects "<case>_0000.nii.gz"; link to the upload when the
            # filesystem allows it, otherwise copy.
            nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
            try:
                nnunet_input.symlink_to(input_path.resolve())
            except (OSError, NotImplementedError):
                shutil.copy2(input_path, nnunet_input)

            raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
            progress(0.20, desc="Running nnUNet inference on GPU (this may take a few minutes)...")
            run_gpu_segmentation(
                str(model_folder),
                str(nnunet_input),
                str(raw_seg),
            )
            if not raw_seg.exists():
                produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
                raise RuntimeError(
                    f"nnUNet did not produce the expected segmentation. Found: {produced}"
                )

            # Copy out before the temporary input/prediction dir is deleted.
            seg_path = output_dir / f"{case_id}_seg.nii.gz"
            shutil.copy2(raw_seg, seg_path)

        progress(0.80, desc="Generating slice images and overlay video...")
        viz = generate_outputs(
            image_path=input_path,
            mask_path=seg_path,
            output_dir=output_dir,
            case_name=case_id,
            cancer_type=config["display_name"],
            wl=config["wl"],
            ww=config["ww"],
            color=config["color"],
            alpha=0.5,
            fps=int(fps),
        )

        progress(0.95, desc="Computing tumour volume...")
        positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)
        stats = (
            f"Case ID        : {case_id}\n"
            f"Cancer type    : {config['display_name']}\n"
            f"Positive voxels: {positive_voxels:,}\n"
            f"Tumour volume  : {tumor_volume_ml:.3f} mL"
        )

        slices = viz["slices"]
        video_path = viz["video"]
        # An empty video file would fail to play; show nothing instead.
        video_out = (
            str(video_path)
            if video_path.exists() and video_path.stat().st_size > 0
            else None
        )

        progress(1.0, desc="Done!")
        return (
            stats,
            str(seg_path),
            _path_or_none(slices.get("centroid")),
            _path_or_none(slices.get("max_area")),
            _path_or_none(slices.get("extent25")),
            _path_or_none(slices.get("extent75")),
            video_out,
        )
    except Exception as e:
        # Remove the partially populated output dir so failed runs leave no files.
        shutil.rmtree(output_dir, ignore_errors=True)
        # Fall back to the exception type when the message is empty.
        raise gr.Error(str(e) or type(e).__name__) from e
# ── UI ─────────────────────────────────────────────────────────────────────────


def build_ui(available_labels=None):
    """Build the Gradio Blocks app.

    Args:
        available_labels: Cancer-type labels to offer. ``None`` or empty means
            every key of ``CANCER_TYPE_CHOICES``. With exactly one label, the
            dropdown is locked and the title/intro name that cancer type.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    labels = available_labels or list(CANCER_TYPE_CHOICES.keys())
    single = len(labels) == 1
    default_label = labels[0]

    if single:
        title = f"# PanCancerSeg — {default_label} CT Segmentation"
        intro = (
            f"Upload a `.nii.gz` CT image and click **Run Inference** to segment "
            f"**{default_label.lower()}** and obtain a mask plus visualisations."
        )
    else:
        title = "# PanCancerSeg — Specialist CT Tumour Segmentation"
        intro = (
            "Upload a `.nii.gz` CT image, select the cancer type, and click "
            "**Run Inference** to obtain a segmentation mask and visualisations."
        )

    # One "Load Example" button per bundled example. With several selectable
    # cancer types, show as many as the best-stocked type has. load_example
    # raises gr.Error for an index the currently selected type lacks. This
    # replaces a hard-coded 2 that ignored how many samples actually exist.
    n_examples = max(count_examples(label) for label in labels)

    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        gr.Markdown(f"{title}\n{intro}")

        with gr.Row():
            # ── Left panel: inputs ─────────────────────────────────────────────
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=labels,
                    value=default_label,
                    label="Cancer Type",
                    interactive=not single,
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )

                example_buttons = []
                if n_examples > 0:
                    with gr.Row():
                        for i in range(1, n_examples + 1):
                            btn_label = "Load Example" if n_examples == 1 else f"Load Example {i}"
                            example_buttons.append(gr.Button(btn_label, size="lg"))

                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # ── Right panel: outputs ───────────────────────────────────────────
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")
                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")
                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        for idx, btn in enumerate(example_buttons, start=1):
            # The outer lambda binds idx now; a bare closure would see only the
            # last loop value (late binding).
            btn.click(
                fn=(lambda i: lambda ct: load_example(ct, i))(idx),
                inputs=[cancer_type],
                outputs=[input_file],
            )

        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )

    return demo
if __name__ == "__main__":
    import os

    # Resolve (and, if needed, download) the weights before serving. This keeps
    # requests from waiting on a download and determines which cancer types the
    # UI offers. A failure here is only reported: resolve_weights_dir() is
    # attempted again lazily on the first request.
    startup_labels = None
    try:
        startup_labels = available_cancer_labels(resolve_weights_dir())
        print(f"[startup] available cancer models: {startup_labels}")
    except Exception as exc:
        print(f"[startup] weight pre-fetch skipped: {exc}")

    demo = build_ui(startup_labels)

    # Spaces serve the app on port 7860 (via GRADIO_SERVER_PORT); the same
    # default applies locally unless the variable overrides it.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",
        server_port=server_port,
        share=False,
        theme=gr.themes.Soft(),
        ssr_mode=False,
        mcp_server=True,
    )