File size: 5,859 Bytes
6858e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""End-to-end zero-shot example for Jolia.

What this shows:
  1. Load the Jolia vision backbone from the Hub (self-contained, trust_remote_code).
  2. Load the paired text encoder (Qwen3-Embedding-8B) and Jolia's bundled
     ``JoliaTextEncoder`` helper.
  3. Encode an image + a list of text prompts into a shared 576-d space and
     compute zero-shot scores (cosine or calibrated logits).

This file works two ways:
  * On the Hub, alongside the repo files (``snapshot_download`` + ``sys.path``).
  * In the staging dir locally, as in this repo's ``scripts/hf_jolia/``.

Heads-up: Qwen3-Embedding-8B is ~18 GB; it downloads to your HF cache on
first use. The text encoder defaults to bfloat16 on CPU; pass
``--text-device cuda`` (keeping ``--text-dtype bf16``) to fit it on a single
GPU.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

import torch
from transformers import AutoModel


def main() -> None:
    """Run the Jolia zero-shot demo from the command line.

    Parses the CLI options, loads the vision model and the text encoder, then
    prints a table of global zero-shot scores followed by a table of per-organ
    scores. The input volume is random noise, so the printed numbers only
    demonstrate the API.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--jolia-id", default="raidium/Jolia", help="Jolia repo id (or local dir).")
    cli.add_argument("--text-id", default="Qwen/Qwen3-Embedding-8B")
    cli.add_argument("--text-dtype", choices=("float32", "float16", "bfloat16", "bf16"), default="bfloat16")
    cli.add_argument("--text-device", default="cpu", help='e.g. "cuda", "cuda:0", "cpu".')
    cli.add_argument("--image-device", default="cpu")
    default_prompts = [
        "a CT showing a liver lesion",
        "a CT showing pneumonia",
        "a normal abdominal CT",
    ]
    cli.add_argument("--prompts", nargs="+", default=default_prompts)
    cli.add_argument(
        "--organ-prompts",
        nargs="+",
        default=["a lesion", "an enlarged organ", "looks normal"],
        help="Short organ-specific findings phrases for per-organ zero-shot.",
    )
    cli.add_argument(
        "--organs",
        nargs="+",
        default=["liver", "spleen", "lungs", "pancreas"],
        help="Organs to score the --organ-prompts against (per-organ zero-shot).",
    )
    opts = cli.parse_args()

    # Step 1 - vision backbone. trust_remote_code=True lets transformers run
    # the modeling code that ships inside the Jolia repo.
    print(f"[1/3] Loading Jolia from {opts.jolia_id} ...")
    vision_model = AutoModel.from_pretrained(opts.jolia_id, trust_remote_code=True)
    vision_model = vision_model.eval().to(opts.image_device)

    # Step 2 - text side. The JoliaTextEncoder class is imported from the
    # Jolia repo itself (a local directory or a Hub snapshot) and wraps the
    # text model named by --text-id.
    print(f"[2/3] Loading text encoder {opts.text_id} ...")
    encoder_cls = _import_text_encoder(opts.jolia_id)
    torch_dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    text_encoder = encoder_cls.from_pretrained(
        opts.text_id,
        dtype=torch_dtypes[opts.text_dtype],
        context_length=vision_model.config.text_context_length,
    )
    text_encoder = text_encoder.eval().to(opts.text_device)

    # Step 3 - zero-shot. A seeded random (1, 11, 192, 192, 192) tensor stands
    # in for a real scan so the example needs no data. For a real CT volume,
    # build the input with JoliaPreprocessor instead:
    #   from preprocessing_jolia import JoliaPreprocessor
    #   image = JoliaPreprocessor()(volume, resolution=(0.7, 0.7, 1.0)).unsqueeze(0)
    print(f"[3/3] Running zero-shot on {len(opts.prompts)} prompts...")
    torch.manual_seed(0)
    volume = torch.randn(1, 11, 192, 192, 192, device=opts.image_device)

    with torch.no_grad():
        # Text features are moved onto the image device before scoring.
        prompt_features = text_encoder(opts.prompts).to(opts.image_device)
        # Calibrated CLIP logits by default: cosine * exp(logit_scale) + bias.
        logits = vision_model.zero_shot(volume, prompt_features)  # (1, N)
        # Independent per-pair "is it a match?" score, also (1, N).
        probs = torch.sigmoid(logits)
        # Uncalibrated cosine in [-1, 1], for ranking-only use cases.
        cosine = vision_model.zero_shot(volume, prompt_features, calibrated=False)

    print("\n--- Global zero-shot (image CLS vs whole-volume text) ---")
    print("prompt                                              logit    sigmoid    cosine")
    score_rows = zip(opts.prompts, logits[0].tolist(), probs[0].tolist(), cosine[0].tolist())
    for prompt, logit, prob, cos in score_rows:
        print(f"  {prompt:50s}  {logit:+.4f}    {prob:.4f}    {cos:+.4f}")
    top = int(logits[0].argmax())
    print(f"best match (argmax logit): {opts.prompts[top]!r}")

    # --- Per-organ (query-routed) zero-shot ----------------------------------
    # Short findings phrases are scored against the query embeddings of the
    # requested organs. This goes through the ParallelOrganCLIP text head,
    # calibrated with each organ's own trained temperature + bias.
    with torch.no_grad():
        finding_features = text_encoder(opts.organ_prompts).to(opts.image_device)
        # Mapping {organ: (B, N)} of calibrated logits.
        organ_scores = vision_model.zero_shot_organs(volume, finding_features, organs=opts.organs)

    print("\n--- Per-organ zero-shot (organ-query vs per-organ text head, calibrated logits) ---")
    # Column titles are the prompts, cut to 18 characters to match cell width.
    column_titles = [f"{p[:18]:>18s}" for p in opts.organ_prompts]
    print(f"  {'organ':12s}  " + "  ".join(column_titles))
    for organ in opts.organs:
        cells = [f"{v:+18.4f}" for v in organ_scores[organ][0].tolist()]
        print(f"  {organ:12s}  " + "  ".join(cells))


def _import_text_encoder(jolia_id_or_dir: str):
    """Return the ``JoliaTextEncoder`` class, fetching the repo if needed.

    Args:
        jolia_id_or_dir: Either a local directory containing
            ``text_encoder_jolia.py`` or a Hub repo id. Anything that is not
            an existing directory is handed to ``snapshot_download``.

    Returns:
        The ``JoliaTextEncoder`` class imported from ``text_encoder_jolia``.

    Raises:
        ModuleNotFoundError: If ``text_encoder_jolia`` is found neither in the
            resolved directory nor anywhere else on ``sys.path``.
    """
    local = Path(jolia_id_or_dir)
    if local.is_dir():
        module_dir = str(local)
    else:
        # Imported only on this branch; a local directory never touches
        # huggingface_hub here.
        from huggingface_hub import snapshot_download

        module_dir = snapshot_download(jolia_id_or_dir)

    # Appended rather than prepended so entries already on sys.path keep
    # priority. Skipped when already listed so repeated calls do not keep
    # growing sys.path with duplicates.
    if module_dir not in sys.path:
        sys.path.append(module_dir)

    try:
        from text_encoder_jolia import JoliaTextEncoder
    except ModuleNotFoundError as err:
        if err.name != "text_encoder_jolia":
            # The helper module was found but one of its own imports is
            # missing; surface that error unchanged.
            raise
        raise ModuleNotFoundError(
            f"text_encoder_jolia.py not found in {module_dir!r} "
            f"(resolved from {jolia_id_or_dir!r}) or elsewhere on sys.path",
            name=err.name,
        ) from err

    return JoliaTextEncoder


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()