Jolia / example_zero_shot.py
SovanK's picture
Upload folder using huggingface_hub
6858e35 verified
Raw
History Blame Contribute Delete
5.86 kB
"""End-to-end zero-shot example for Jolia.
What this shows:
1. Load the Jolia vision backbone from the Hub (self-contained, trust_remote_code).
2. Load the paired text encoder (Qwen3-Embedding-8B) and Jolia's bundled
``JoliaTextEncoder`` helper.
3. Encode an image + a list of text prompts into a shared 576-d space and
compute zero-shot scores (cosine or calibrated logits).
This file works two ways:
* On the Hub, alongside the repo files (``snapshot_download`` + ``sys.path``).
* In the staging dir locally, as in this repo's ``scripts/hf_jolia/``.
Heads-up: Qwen3-Embedding-8B is ~18 GB; it downloads to your HF cache on
first use. The text encoder loads in bfloat16 by default (``--text-dtype bf16``),
which lets it fit on a single GPU when you pass ``--text-device cuda``.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
from transformers import AutoModel
# Accepted ``--text-dtype`` spellings -> torch dtypes. Single source of truth for
# both the CLI ``choices`` and the lookup ("bf16" is an alias of "bfloat16").
_TEXT_DTYPES = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
}


def main() -> None:
    """Run the end-to-end demo: load both models, score prompts, print the tables."""
    args = _parse_args()
    jolia, text_encoder = _load_models(args)

    # 3) Zero-shot. The image here is a random (1, 11, 192, 192, 192) tensor
    # so the example is self-contained. In a real run, build it with
    # JoliaPreprocessor on a CT volume:
    #   from preprocessing_jolia import JoliaPreprocessor
    #   image = JoliaPreprocessor()(volume, resolution=(0.7, 0.7, 1.0)).unsqueeze(0)
    print(f"[3/3] Running zero-shot on {len(args.prompts)} prompts...")
    torch.manual_seed(0)  # deterministic stand-in image
    image = torch.randn(1, 11, 192, 192, 192, device=args.image_device)

    _global_zero_shot(jolia, text_encoder, image, args.prompts, args.image_device)
    _per_organ_zero_shot(
        jolia, text_encoder, image, args.organ_prompts, args.organs, args.image_device
    )


def _parse_args() -> argparse.Namespace:
    """Build the example's CLI and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--jolia-id", default="raidium/Jolia", help="Jolia repo id (or local dir).")
    parser.add_argument(
        "--text-id",
        default="Qwen/Qwen3-Embedding-8B",
        help="Text-embedding model passed to JoliaTextEncoder.from_pretrained.",
    )
    parser.add_argument(
        "--text-dtype",
        choices=tuple(_TEXT_DTYPES),
        default="bfloat16",
        help="Weight dtype of the text encoder (bf16 is an alias for bfloat16).",
    )
    parser.add_argument("--text-device", default="cpu", help='e.g. "cuda", "cuda:0", "cpu".')
    parser.add_argument(
        "--image-device",
        default="cpu",
        help="Device for the Jolia vision model, the image tensor and the scores.",
    )
    parser.add_argument(
        "--prompts",
        nargs="+",
        default=[
            "a CT showing a liver lesion",
            "a CT showing pneumonia",
            "a normal abdominal CT",
        ],
        help="Whole-volume text prompts for global zero-shot.",
    )
    parser.add_argument(
        "--organ-prompts",
        nargs="+",
        default=["a lesion", "an enlarged organ", "looks normal"],
        help="Short organ-specific findings phrases for per-organ zero-shot.",
    )
    parser.add_argument(
        "--organs",
        nargs="+",
        default=["liver", "spleen", "lungs", "pancreas"],
        help="Organs to score the --organ-prompts against (per-organ zero-shot).",
    )
    return parser.parse_args()


def _load_models(args: argparse.Namespace):
    """Load the Jolia vision model and its paired text encoder, both in eval mode.

    Returns a ``(jolia, text_encoder)`` tuple placed on ``args.image_device`` and
    ``args.text_device`` respectively.
    """
    # 1) Vision: Jolia from the Hub. trust_remote_code=True pulls the small
    # self-contained modeling code shipped with the repo.
    print(f"[1/3] Loading Jolia from {args.jolia_id} ...")
    jolia = AutoModel.from_pretrained(args.jolia_id, trust_remote_code=True).eval().to(args.image_device)

    # 2) Text: Qwen3-Embedding-8B + Jolia's bundled JoliaTextEncoder helper.
    # The helper comes from --jolia-id itself when that is a local dir (e.g. the
    # staging dir), otherwise from a snapshot of the Hub repo.
    print(f"[2/3] Loading text encoder {args.text_id} ...")
    JoliaTextEncoder = _import_text_encoder(args.jolia_id)
    text_encoder = (
        JoliaTextEncoder.from_pretrained(
            args.text_id,
            dtype=_TEXT_DTYPES[args.text_dtype],
            # Tokenize to the same context length the vision model was paired with.
            context_length=jolia.config.text_context_length,
        )
        .eval()
        .to(args.text_device)
    )
    return jolia, text_encoder


def _global_zero_shot(jolia, text_encoder, image: torch.Tensor, prompts: list[str], device: str) -> None:
    """Score *prompts* against the whole-volume image embedding and print a table.

    Prints, per prompt, the calibrated logit, its sigmoid and the raw cosine,
    then the prompt with the highest logit.
    """
    with torch.no_grad():
        # Text features may live on another device than the image; align them.
        text_features = text_encoder(prompts).to(device)
        # Default is calibrated CLIP logits: cosine * exp(logit_scale) + bias.
        logits = jolia.zero_shot(image, text_features)  # (1, N)
        probs = torch.sigmoid(logits)  # (1, N) per-pair "is it a match?"
        # Raw cosine in [-1, 1] for ranking-only use cases.
        cosine = jolia.zero_shot(image, text_features, calibrated=False)

    # Size the prompt column to the longest prompt so header and rows line up.
    width = max([len("prompt"), *map(len, prompts)])
    print("\n--- Global zero-shot (image CLS vs whole-volume text) ---")
    print(f" {'prompt':<{width}s} {'logit':>10s} {'sigmoid':>10s} {'cosine':>10s}")
    for p, lg, pr, c in zip(prompts, logits[0].tolist(), probs[0].tolist(), cosine[0].tolist()):
        print(f" {p:<{width}s} {lg:+10.4f} {pr:10.4f} {c:+10.4f}")
    best = int(logits[0].argmax())
    print(f"best match (argmax logit): {prompts[best]!r}")


def _per_organ_zero_shot(
    jolia,
    text_encoder,
    image: torch.Tensor,
    organ_prompts: list[str],
    organs: list[str],
    device: str,
) -> None:
    """Score short findings phrases against specific organs' query embeddings.

    Uses Jolia's per-organ (ParallelOrganCLIP) text head, calibrated with each
    organ's own trained temperature + bias, and prints an organ x prompt table
    of calibrated logits.
    """
    with torch.no_grad():
        organ_text = text_encoder(organ_prompts).to(device)
        # {organ: (B, N)} — calibrated logits
        per_organ_logits = jolia.zero_shot_organs(image, organ_text, organs=organs)

    print("\n--- Per-organ zero-shot (organ-query vs per-organ text head, calibrated logits) ---")
    header = f" {'organ':12s} " + " ".join(f"{p[:18]:>18s}" for p in organ_prompts)
    print(header)
    for organ in organs:
        row = per_organ_logits[organ][0].tolist()
        print(f" {organ:12s} " + " ".join(f"{v:+18.4f}" for v in row))
def _import_text_encoder(jolia_id_or_dir: str):
"""Return the ``JoliaTextEncoder`` class, fetching the repo if needed."""
local = Path(jolia_id_or_dir)
if local.is_dir():
sys.path.append(str(local))
else:
from huggingface_hub import snapshot_download
repo_dir = snapshot_download(jolia_id_or_dir)
sys.path.append(repo_dir)
from text_encoder_jolia import JoliaTextEncoder
return JoliaTextEncoder
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())