"""End-to-end zero-shot example for Jolia. What this shows: 1. Load the Jolia vision backbone from the Hub (self-contained, trust_remote_code). 2. Load the paired text encoder (Qwen3-Embedding-8B) and Jolia's bundled ``JoliaTextEncoder`` helper. 3. Encode an image + a list of text prompts into a shared 576-d space and compute zero-shot scores (cosine or calibrated logits). This file works two ways: * On the Hub, alongside the repo files (``snapshot_download`` + ``sys.path``). * In the staging dir locally, as in this repo's ``scripts/hf_jolia/``. Heads-up: Qwen3-Embedding-8B is ~18 GB; it downloads to your HF cache on first use. Pass ``--text-dtype bf16 --text-device cuda`` to fit on a single GPU. """ from __future__ import annotations import argparse import sys from pathlib import Path import torch from transformers import AutoModel def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--jolia-id", default="raidium/Jolia", help="Jolia repo id (or local dir).") parser.add_argument("--text-id", default="Qwen/Qwen3-Embedding-8B") parser.add_argument( "--text-dtype", choices=("float32", "float16", "bfloat16", "bf16"), default="bfloat16", ) parser.add_argument("--text-device", default="cpu", help='e.g. "cuda", "cuda:0", "cpu".') parser.add_argument("--image-device", default="cpu") parser.add_argument( "--prompts", nargs="+", default=[ "a CT showing a liver lesion", "a CT showing pneumonia", "a normal abdominal CT", ], ) parser.add_argument( "--organ-prompts", nargs="+", default=["a lesion", "an enlarged organ", "looks normal"], help="Short organ-specific findings phrases for per-organ zero-shot.", ) parser.add_argument( "--organs", nargs="+", default=["liver", "spleen", "lungs", "pancreas"], help="Organs to score the --organ-prompts against (per-organ zero-shot).", ) args = parser.parse_args() # 1) Vision: Jolia from the Hub. trust_remote_code=True pulls the small # self-contained modeling code shipped with the repo. print(f"[1/3] Loading Jolia from {args.jolia_id} ...") jolia = AutoModel.from_pretrained(args.jolia_id, trust_remote_code=True).eval().to(args.image_device) # 2) Text: Qwen3-Embedding-8B + Jolia's bundled JoliaTextEncoder helper. # We get JoliaTextEncoder either from the local staging dir (if we're # running from rarm/scripts/hf_jolia) or by snapshotting the Hub repo. print(f"[2/3] Loading text encoder {args.text_id} ...") JoliaTextEncoder = _import_text_encoder(args.jolia_id) dtype = {"bf16": torch.bfloat16, "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[ args.text_dtype ] text_encoder = ( JoliaTextEncoder.from_pretrained(args.text_id, dtype=dtype, context_length=jolia.config.text_context_length) .eval() .to(args.text_device) ) # 3) Zero-shot. The image here is a random (1, 11, 192, 192, 192) tensor # so the example is self-contained. In a real run, build it with # JoliaPreprocessor on a CT volume: # from preprocessing_jolia import JoliaPreprocessor # image = JoliaPreprocessor()(volume, resolution=(0.7, 0.7, 1.0)).unsqueeze(0) print(f"[3/3] Running zero-shot on {len(args.prompts)} prompts...") torch.manual_seed(0) image = torch.randn(1, 11, 192, 192, 192, device=args.image_device) with torch.no_grad(): text_features = text_encoder(args.prompts).to(args.image_device) # Default is calibrated CLIP logits: cosine * exp(logit_scale) + bias. logits = jolia.zero_shot(image, text_features) # (1, N) probs = torch.sigmoid(logits) # (1, N) per-pair "is it a match?" # Raw cosine in [-1, 1] for ranking-only use cases. cosine = jolia.zero_shot(image, text_features, calibrated=False) print("\n--- Global zero-shot (image CLS vs whole-volume text) ---") print("prompt logit sigmoid cosine") for p, lg, pr, c in zip(args.prompts, logits[0].tolist(), probs[0].tolist(), cosine[0].tolist()): print(f" {p:50s} {lg:+.4f} {pr:.4f} {c:+.4f}") best = int(logits[0].argmax()) print(f"best match (argmax logit): {args.prompts[best]!r}") # --- Per-organ (query-routed) zero-shot ---------------------------------- # Score short findings phrases against specific organs' query embeddings. # Uses the ParallelOrganCLIP text head — calibrated with each organ's own # trained temperature + bias. with torch.no_grad(): organ_text = text_encoder(args.organ_prompts).to(args.image_device) per_organ_logits = jolia.zero_shot_organs(image, organ_text, organs=args.organs) # {organ: (B, N)} — calibrated logits print("\n--- Per-organ zero-shot (organ-query vs per-organ text head, calibrated logits) ---") header = f" {'organ':12s} " + " ".join(f"{p[:18]:>18s}" for p in args.organ_prompts) print(header) for organ in args.organs: row = per_organ_logits[organ][0].tolist() print(f" {organ:12s} " + " ".join(f"{v:+18.4f}" for v in row)) def _import_text_encoder(jolia_id_or_dir: str): """Return the ``JoliaTextEncoder`` class, fetching the repo if needed.""" local = Path(jolia_id_or_dir) if local.is_dir(): sys.path.append(str(local)) else: from huggingface_hub import snapshot_download repo_dir = snapshot_download(jolia_id_or_dir) sys.path.append(repo_dir) from text_encoder_jolia import JoliaTextEncoder return JoliaTextEncoder if __name__ == "__main__": main()