Feature Extraction
Transformers
Safetensors
jolia
medical
radiology
ct
3d
vision
foundation-model
self-supervised
custom_code
Instructions to use raidium/Jolia with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use raidium/Jolia with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="raidium/Jolia", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 5,859 Bytes
"""End-to-end zero-shot example for Jolia.
What this shows:
1. Load the Jolia vision backbone from the Hub (self-contained, trust_remote_code).
2. Load the paired text encoder (Qwen3-Embedding-8B) and Jolia's bundled
``JoliaTextEncoder`` helper.
3. Encode an image + a list of text prompts into a shared 576-d space and
compute zero-shot scores (cosine or calibrated logits), both globally and
per organ (short findings phrases scored against each organ's query).
This file works two ways:
* On the Hub, alongside the repo files (``snapshot_download`` + ``sys.path``).
* In the staging dir locally, as in this repo's ``scripts/hf_jolia/``.
Heads-up: Qwen3-Embedding-8B is ~18 GB; it downloads to your HF cache on
first use. It is loaded in bfloat16 by default (``--text-dtype``); pass
``--text-device cuda`` to run it on a single GPU.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
from transformers import AutoModel
def main() -> None:
    """Run the Jolia zero-shot demo end to end.

    Steps:
      1. Load the Jolia vision backbone with ``AutoModel`` (``trust_remote_code``).
      2. Load the paired text encoder through Jolia's ``JoliaTextEncoder`` helper,
         embed both the global prompts and the per-organ prompts, then release
         the text encoder.
      3. Score a random ``(1, 11, 192, 192, 192)`` volume against the prompts,
         globally (``jolia.zero_shot``) and per organ (``jolia.zero_shot_organs``),
         and print the results.

    All options come from the command line (see ``--help``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jolia-id", default="raidium/Jolia", help="Jolia repo id (or local dir).")
    parser.add_argument("--text-id", default="Qwen/Qwen3-Embedding-8B")
    parser.add_argument(
        "--text-dtype",
        choices=("float32", "float16", "bfloat16", "bf16"),
        default="bfloat16",
    )
    parser.add_argument("--text-device", default="cpu", help='e.g. "cuda", "cuda:0", "cpu".')
    parser.add_argument("--image-device", default="cpu")
    parser.add_argument(
        "--prompts",
        nargs="+",
        default=[
            "a CT showing a liver lesion",
            "a CT showing pneumonia",
            "a normal abdominal CT",
        ],
    )
    parser.add_argument(
        "--organ-prompts",
        nargs="+",
        default=["a lesion", "an enlarged organ", "looks normal"],
        help="Short organ-specific findings phrases for per-organ zero-shot.",
    )
    parser.add_argument(
        "--organs",
        nargs="+",
        default=["liver", "spleen", "lungs", "pancreas"],
        help="Organs to score the --organ-prompts against (per-organ zero-shot).",
    )
    args = parser.parse_args()

    # 1) Vision: Jolia from the Hub. trust_remote_code=True pulls the small
    #    self-contained modeling code shipped with the repo.
    print(f"[1/3] Loading Jolia from {args.jolia_id} ...")
    jolia = AutoModel.from_pretrained(args.jolia_id, trust_remote_code=True).eval().to(args.image_device)
    # Everything fed to Jolia is cast to the backbone's own dtype. The text
    # encoder defaults to bfloat16 while Jolia is loaded without an explicit
    # dtype; mixing dtypes in a matmul/linear layer errors out unless
    # zero_shot happens to cast internally, so align them here.
    image_dtype = jolia.dtype

    # 2) Text: Qwen3-Embedding-8B + Jolia's bundled JoliaTextEncoder helper.
    #    JoliaTextEncoder comes from a local dir (the --jolia-id dir, or a copy
    #    next to this script) or from a snapshot of the Hub repo.
    print(f"[2/3] Loading text encoder {args.text_id} ...")
    JoliaTextEncoder = _import_text_encoder(args.jolia_id)
    dtype = {"bf16": torch.bfloat16, "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[
        args.text_dtype
    ]
    text_encoder = (
        JoliaTextEncoder.from_pretrained(args.text_id, dtype=dtype, context_length=jolia.config.text_context_length)
        .eval()
        .to(args.text_device)
    )
    # Embed both prompt sets now, then drop the (multi-GB) text encoder so it
    # does not stay resident while the 3D vision backbone runs.
    with torch.no_grad():
        text_features = text_encoder(args.prompts).to(device=args.image_device, dtype=image_dtype)
        organ_text = text_encoder(args.organ_prompts).to(device=args.image_device, dtype=image_dtype)
    del text_encoder
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # 3) Zero-shot. The image here is a random (1, 11, 192, 192, 192) tensor
    #    so the example is self-contained. In a real run, build it with
    #    JoliaPreprocessor on a CT volume and move it to the same device/dtype:
    #      from preprocessing_jolia import JoliaPreprocessor
    #      image = JoliaPreprocessor()(volume, resolution=(0.7, 0.7, 1.0)).unsqueeze(0)
    print(f"[3/3] Running zero-shot on {len(args.prompts)} prompts...")
    torch.manual_seed(0)
    image = torch.randn(1, 11, 192, 192, 192, device=args.image_device, dtype=image_dtype)
    with torch.no_grad():
        # Default is calibrated CLIP logits: cosine * exp(logit_scale) + bias.
        logits = jolia.zero_shot(image, text_features)  # (1, N)
        probs = torch.sigmoid(logits)  # (1, N) per-pair "is it a match?"
        # Raw cosine in [-1, 1] for ranking-only use cases.
        cosine = jolia.zero_shot(image, text_features, calibrated=False)

    print("\n--- Global zero-shot (image CLS vs whole-volume text) ---")
    # Header uses the same field widths as the rows below so columns line up.
    print(f"  {'prompt':50s}  {'logit':>8s}  {'sigmoid':>8s}  {'cosine':>8s}")
    for prompt, logit, prob, cos in zip(args.prompts, logits[0].tolist(), probs[0].tolist(), cosine[0].tolist()):
        print(f"  {prompt:50s}  {logit:+8.4f}  {prob:8.4f}  {cos:+8.4f}")
    best = int(logits[0].argmax())
    print(f"best match (argmax logit): {args.prompts[best]!r}")

    # --- Per-organ (query-routed) zero-shot ----------------------------------
    # Score short findings phrases against specific organs' query embeddings.
    # Uses the ParallelOrganCLIP text head — calibrated with each organ's own
    # trained temperature + bias.
    with torch.no_grad():
        # {organ: (B, N)} — calibrated logits
        per_organ_logits = jolia.zero_shot_organs(image, organ_text, organs=args.organs)

    print("\n--- Per-organ zero-shot (organ-query vs per-organ text head, calibrated logits) ---")
    header = f"  {'organ':12s} " + " ".join(f"{p[:18]:>18s}" for p in args.organ_prompts)
    print(header)
    for organ in args.organs:
        row = per_organ_logits[organ][0].tolist()
        print(f"  {organ:12s} " + " ".join(f"{v:+18.4f}" for v in row))
def _import_text_encoder(jolia_id_or_dir: str) -> type:
    """Return the ``JoliaTextEncoder`` class, fetching the repo if needed.

    ``text_encoder_jolia.py`` is looked up, in order, in:

    1. ``jolia_id_or_dir`` itself, when it is an existing local directory;
    2. the directory containing this script, when it ships a copy (e.g. the
       local staging dir) — this skips a needless Hub round-trip;
    3. a ``snapshot_download`` of the Hub repo ``jolia_id_or_dir``.

    The chosen directory is put at the *front* of ``sys.path`` so it takes
    precedence over any other ``text_encoder_jolia`` module already
    importable (appending would let e.g. the script's own directory, which is
    ``sys.path[0]``, shadow an explicitly requested local dir).

    Args:
        jolia_id_or_dir: Hub repo id (e.g. ``"raidium/Jolia"``) or local directory.

    Returns:
        The ``JoliaTextEncoder`` class object.
    """
    local = Path(jolia_id_or_dir)
    # ``__file__`` is absent when this code runs in a notebook / REPL.
    script_file = globals().get("__file__")
    script_dir = Path(script_file).resolve().parent if script_file else None

    if local.is_dir():
        source_dir = local
    elif script_dir is not None and (script_dir / "text_encoder_jolia.py").is_file():
        source_dir = script_dir
    else:
        from huggingface_hub import snapshot_download

        source_dir = Path(snapshot_download(jolia_id_or_dir))

    source_path = str(source_dir.resolve())
    if not sys.path or sys.path[0] != source_path:
        sys.path.insert(0, source_path)
    from text_encoder_jolia import JoliaTextEncoder

    return JoliaTextEncoder
# Entry point: only run the demo when executed as a script, not on import.
if __name__ == "__main__":
    main()
|