# Commit 4a059a8 (verified) by junma: Enable MCP server (mcp_server=True)
"""Gradio UI for PanCancerSeg single-case CT tumour segmentation."""
import shutil
import tempfile
from pathlib import Path
import gradio as gr
from predict import (
CANCER_CONFIGS,
install_custom_trainer,
resolve_case_id,
resolve_model_folder,
run_nnunet_prediction_single,
summarize_segmentation,
)
from visualize import generate_outputs
# ── Constants ──────────────────────────────────────────────────────────────────
CANCER_TYPE_CHOICES = {
"Kidney Cancer": "kidney_cancer",
"Liver Cancer": "liver_cancer",
"Pancreatic Cancer": "pancreatic_cancer",
"Lung Cancer": "lung_cancer",
}
DEFAULT_MODEL_DIR = str(Path(__file__).parent / "PanCancerSeg-Specialized-weights")
DEFAULT_DEVICE = "cuda"
# Hugging Face Hub repo that hosts the trained nnUNet weights. On Spaces (where the
# local weights folder is absent) we download them on first use.
MODEL_REPO_ID = "KS987/PanCancerSeg-Specialized-weights"
# Resolved once per process; subsequent inferences reuse it (no re-download).
_WEIGHTS_DIR: Path | None = None
def resolve_weights_dir() -> Path:
    """Locate the directory that holds the ``Dataset*`` model folders.

    A local checkout at ``DEFAULT_MODEL_DIR`` wins when it exists and contains
    at least one ``Dataset*`` entry. Otherwise the weights are fetched from the
    Hugging Face Hub repository ``MODEL_REPO_ID``. The result is stored in the
    module-level ``_WEIGHTS_DIR`` so the lookup (and any download) happens at
    most once per process.

    Returns:
        Path of the directory containing the model folders.
    """
    global _WEIGHTS_DIR

    if _WEIGHTS_DIR is None:
        candidate = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
        has_local_models = candidate.exists() and any(candidate.glob("Dataset*"))

        if has_local_models:
            _WEIGHTS_DIR = candidate
        else:
            # Only imported when a download is actually needed.
            from huggingface_hub import snapshot_download

            _WEIGHTS_DIR = Path(
                snapshot_download(
                    repo_id=MODEL_REPO_ID,
                    repo_type="model",
                    allow_patterns=["Dataset*/**"],
                )
            )

    return _WEIGHTS_DIR
# ── ZeroGPU support ──────────────────────────────────────────────────────────
# On Hugging Face ZeroGPU Spaces the `spaces` package is available, and any GPU
# work must run inside a function decorated with `@spaces.GPU`. Locally (or on a
# dedicated GPU Space) the package is absent, so we fall back to a no-op so the
# same code keeps working everywhere.
try:
    import spaces  # type: ignore
except ImportError:
    # No ZeroGPU runtime available: decorators built by gpu_task become no-ops.
    spaces = None
    _HAS_ZEROGPU = False
else:
    _HAS_ZEROGPU = True


def gpu_task(duration: int = 180):
    """Build a decorator for functions that need the GPU.

    Args:
        duration: Value forwarded as ``duration`` to ``spaces.GPU`` when the
            ``spaces`` package could be imported.

    Returns:
        ``spaces.GPU(duration=duration)`` when ``spaces`` is importable,
        otherwise a decorator that returns the function unchanged.
    """
    if not _HAS_ZEROGPU:

        def _passthrough(fn):
            return fn

        return _passthrough

    return spaces.GPU(duration=duration)
@gpu_task(duration=180)
def run_gpu_segmentation(model_folder_str: str, input_file_str: str, output_file_str: str) -> None:
    """Run single-case nnUNet inference on the GPU.

    This function is what gets wrapped by the ZeroGPU decorator. It uses the
    single-case, no-multiprocessing prediction path because ZeroGPU executes it
    in a daemon process that may not spawn child processes.

    Args:
        model_folder_str: Path of the trained model folder.
        input_file_str: Path of the input CT volume.
        output_file_str: Path where the segmentation should be written.
    """
    # Register the custom trainer in this (worker) process so nnUNet can find
    # it when it initialises from the trained model folder.
    install_custom_trainer()

    prediction_args = {
        "model_folder": model_folder_str,
        "input_file": input_file_str,
        "output_file": output_file_str,
        "device": "cuda",
    }
    run_nnunet_prediction_single(**prediction_args)
# Folder next to this file that holds the bundled example CT volumes.
_SAMPLE_DIR = Path(__file__).parent.joinpath("sample_input")

# UI label -> sub-folder of _SAMPLE_DIR with that cancer type's examples.
_CANCER_TYPE_TO_FOLDER = {
    "Kidney Cancer": "kidney",
    "Liver Cancer": "liver",
    "Pancreatic Cancer": "pancreas",
    "Lung Cancer": "lung",
}
def load_example(cancer_type_label: str, index: int) -> str:
    """Return the path of a bundled example CT volume.

    Args:
        cancer_type_label: UI label of the cancer type; must be a key of
            ``_CANCER_TYPE_TO_FOLDER``.
        index: 1-based position of the example among the folder's
            ``*_0000.nii.gz`` files, sorted by path.

    Returns:
        The selected example's path as a string.

    Raises:
        gr.Error: If the label is unknown or there is no example at ``index``.
    """
    subfolder = _CANCER_TYPE_TO_FOLDER.get(cancer_type_label)
    if subfolder is None:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")

    folder = _SAMPLE_DIR / subfolder
    files = sorted(folder.glob("*_0000.nii.gz"))

    # Reject index < 1 explicitly: files[index - 1] would otherwise wrap around
    # and silently return a file counted from the end of the list.
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])
def count_examples(cancer_type_label: str) -> int:
    """Count the bundled ``*_0000.nii.gz`` example volumes for a cancer type.

    Returns 0 when the cancer type's sample folder does not exist.
    """
    sample_folder = _SAMPLE_DIR / _CANCER_TYPE_TO_FOLDER[cancer_type_label]
    if sample_folder.exists():
        return sum(1 for _ in sample_folder.glob("*_0000.nii.gz"))
    return 0
def available_cancer_labels(weights_dir) -> list:
    """List the cancer labels whose dataset folder exists under ``weights_dir``.

    Each label in ``CANCER_TYPE_CHOICES`` is kept when the folder named by its
    ``dataset_name`` entry in ``CANCER_CONFIGS`` is present. A Space bundling a
    single dataset therefore yields one label (the UI locks to it), while a
    full checkout yields every label. If no dataset folder is found at all,
    every label is returned.
    """
    root = Path(weights_dir)

    present = []
    for label, key in CANCER_TYPE_CHOICES.items():
        dataset_dir = root / CANCER_CONFIGS[key]["dataset_name"]
        if dataset_dir.exists():
            present.append(label)

    if not present:
        return list(CANCER_TYPE_CHOICES.keys())
    return present
# ── Inference ──────────────────────────────────────────────────────────────────
def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment one uploaded CT volume and build the outputs shown in the UI.

    Args:
        input_file: Path of the uploaded ``.nii.gz`` CT image, or ``None`` when
            nothing was uploaded.
        cancer_type_label: UI label of the cancer type; must be a key of
            ``CANCER_TYPE_CHOICES``.
        fps: Frame rate for the overlay video; converted with ``int()``.
        progress: Gradio progress tracker used to report the pipeline stage.

    Returns:
        A 7-tuple: the summary text, the path of the segmentation mask, the
        paths of the ``centroid``, ``max_area``, ``extent25`` and ``extent75``
        slice images (``None`` for any slice that was not produced), and the
        overlay video path (``None`` when the video is missing or empty).

    Raises:
        gr.Error: For invalid input, when the weights cannot be obtained, or
            when any step of the pipeline fails.
    """
    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")

    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    # Validate the label up front so an unexpected value surfaces as a readable
    # error instead of a bare KeyError.
    cancer_key = CANCER_TYPE_CHOICES.get(cancer_type_label)
    if cancer_key is None:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")
    config = CANCER_CONFIGS[cancer_key]

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}") from e

    # Not a TemporaryDirectory: the returned paths point into this folder, so it
    # must outlive the call. It is only removed when the pipeline fails.
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))

    try:
        case_id = resolve_case_id(input_path)

        progress(0.10, desc="Loading model weights...")
        model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])

        seg_path = output_dir / f"{case_id}_seg.nii.gz"

        with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
            tmp_path = Path(tmp)
            tmp_input_dir = tmp_path / "input"
            tmp_output_dir = tmp_path / "prediction"
            tmp_input_dir.mkdir()
            tmp_output_dir.mkdir()

            # Expose the upload under the "<case>_0000.nii.gz" name; fall back
            # to a copy where symlinks are unavailable.
            nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
            try:
                nnunet_input.symlink_to(input_path.resolve())
            except (OSError, NotImplementedError):
                shutil.copy2(input_path, nnunet_input)

            raw_seg = tmp_output_dir / f"{case_id}.nii.gz"

            progress(0.20, desc="Running nnUNet inference on GPU (this may take a few minutes)...")
            run_gpu_segmentation(
                str(model_folder),
                str(nnunet_input),
                str(raw_seg),
            )

            if not raw_seg.exists():
                produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
                raise RuntimeError(
                    f"nnUNet did not produce the expected segmentation. Found: {produced}"
                )

            # Copy the mask out before the temporary working folder is deleted.
            shutil.copy2(raw_seg, seg_path)

        progress(0.80, desc="Generating slice images and overlay video...")
        viz = generate_outputs(
            image_path=input_path,
            mask_path=seg_path,
            output_dir=output_dir,
            case_name=case_id,
            cancer_type=config["display_name"],
            wl=config["wl"],
            ww=config["ww"],
            color=config["color"],
            alpha=0.5,
            fps=int(fps),
        )

        progress(0.95, desc="Computing tumour volume...")
        positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)

        stats = (
            f"Case ID : {case_id}\n"
            f"Cancer type : {config['display_name']}\n"
            f"Positive voxels: {positive_voxels:,}\n"
            f"Tumour volume : {tumor_volume_ml:.3f} mL"
        )

        # A missing slice must stay None; str(None) would hand the UI the
        # literal path "None".
        slices = viz["slices"]
        slice_paths = [
            None if slices.get(name) is None else str(slices[name])
            for name in ("centroid", "max_area", "extent25", "extent75")
        ]

        video_path = viz["video"]
        video_file = Path(video_path) if video_path else None
        video_out = (
            str(video_file)
            if video_file is not None
            and video_file.exists()
            and video_file.stat().st_size > 0
            else None
        )

        progress(1.0, desc="Done!")
        return (
            stats,
            str(seg_path),
            *slice_paths,
            video_out,
        )

    except Exception as e:
        shutil.rmtree(output_dir, ignore_errors=True)
        if isinstance(e, gr.Error):
            # Already user-facing; re-raise unchanged rather than re-wrapping.
            raise
        raise gr.Error(str(e)) from e
# ── UI ─────────────────────────────────────────────────────────────────────────
def build_ui(available_labels: list[str] | None = None) -> gr.Blocks:
    """Assemble the Gradio interface and wire its event handlers.

    Args:
        available_labels: Cancer-type labels to offer. When empty or ``None``,
            every key of ``CANCER_TYPE_CHOICES`` is offered. With exactly one
            label the title and intro name that cancer type and the dropdown
            is not interactive.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    labels = available_labels or list(CANCER_TYPE_CHOICES.keys())
    single = len(labels) == 1
    default_label = labels[0]

    if single:
        title = f"# PanCancerSeg — {default_label} CT Segmentation"
        intro = (
            f"Upload a `.nii.gz` CT image and click **Run Inference** to segment "
            f"**{default_label.lower()}** and obtain a mask plus visualisations."
        )
    else:
        title = "# PanCancerSeg — Specialist CT Tumour Segmentation"
        intro = (
            "Upload a `.nii.gz` CT image, select the cancer type, and click "
            "**Run Inference** to obtain a segmentation mask and visualisations."
        )

    # Single-cancer mode shows one button per bundled example; multi-cancer
    # mode always shows two buttons, whatever each folder actually contains.
    n_examples = count_examples(default_label) if single else 2

    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        gr.Markdown(f"{title}\n{intro}")

        with gr.Row():
            # ── Left panel: inputs ─────────────────────────────────────────────
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=labels,
                    value=default_label,
                    label="Cancer Type",
                    interactive=not single,
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )
                # Buttons are collected here and wired to load_example below.
                example_buttons = []
                if n_examples > 0:
                    with gr.Row():
                        for i in range(1, n_examples + 1):
                            label = "Load Example" if n_examples == 1 else f"Load Example {i}"
                            example_buttons.append(gr.Button(label, size="lg"))
                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # ── Right panel: outputs ───────────────────────────────────────────
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")
                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")
                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        for idx, btn in enumerate(example_buttons, start=1):
            btn.click(
                # The outer lambda is called immediately with idx, so each
                # handler keeps its own example number instead of sharing the
                # loop variable's final value.
                fn=(lambda i: lambda ct: load_example(ct, i))(idx),
                inputs=[cancer_type],
                outputs=[input_file],
            )

        # The outputs list follows the order of the tuple returned by
        # run_inference.
        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )

    return demo
if __name__ == "__main__":
    import os

    # Resolve the weights at startup so no inference has to wait for a
    # download. A failure here is not fatal: the UI falls back to all cancer
    # types and the weights are resolved lazily on the first request.
    labels = None
    try:
        labels = available_cancer_labels(resolve_weights_dir())
        print(f"[startup] available cancer models: {labels}")
    except Exception as e:
        print(f"[startup] weight pre-fetch skipped: {e}")

    demo = build_ui(labels)

    # Hugging Face Spaces expect the app on port 7860 (set via
    # GRADIO_SERVER_PORT); the same default applies locally unless overridden.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.getenv("GRADIO_SERVER_PORT", 7860)),
        "share": False,
        "theme": gr.themes.Soft(),
        "ssr_mode": False,
        "mcp_server": True,
    }
    demo.launch(**launch_options)