sayehghp committed
Commit b94ff1b · 1 Parent(s): d148884

Displaying attention image

Files changed (4)
  1. app.py +72 -27
  2. inference.py +56 -1
  3. vg_token_attention.py +396 -0
  4. vicca_api.py +28 -0
app.py CHANGED
@@ -5,45 +5,90 @@ import tempfile
 import gradio as gr
 from vicca_api import run_vicca
 
-def vicca_interface(image, text_prompt):
-    """
-    image: file from Gradio, we'll use its temp path
-    text_prompt: report / description
-    """
-    # Gradio passes a PIL image or a file path depending on type
-    # We'll request type='filepath' so this is already a path
-    image_path = image
+def vicca_interface(image, text_prompt, box_threshold=0.2, text_threshold=0.2, num_samples=4):
+    os.makedirs("uploads", exist_ok=True)
+    input_path = os.path.join("uploads", "input.png")
+    image.save(input_path)
 
     result = run_vicca(
-        image_path=image_path,
+        image_path=input_path,
         text_prompt=text_prompt,
+        box_threshold=box_threshold,
+        text_threshold=text_threshold,
+        num_samples=num_samples,
     )
 
-    # You could also return the best generated image as an image output
-    # For now, we expose the dict as JSON
-    return result
+    best_gen = result.get("best_generated_image_path")
+
+    attn = result.get("attention_overlays") or {}
+    combined = attn.get("combined")
+    per_term_dict = attn.get("per_term") or {}
+
+    gallery_items = [(p, term) for term, p in per_term_dict.items()]
+
+    return best_gen, combined, gallery_items, result
 
 demo = gr.Interface(
     fn=vicca_interface,
     inputs=[
-        gr.Image(type="filepath", label="Chest X-ray"),
-        gr.Textbox(label="Report / pathology description", lines=3),
+        gr.Image(type="pil", label="Input CXR"),
+        gr.Textbox(lines=3, label="Text prompt"),
+        gr.Slider(0.0, 1.0, value=0.2, label="Box threshold"),
+        gr.Slider(0.0, 1.0, value=0.2, label="Text threshold"),
+        gr.Slider(1, 8, step=1, value=4, label="Number of samples"),
+    ],
+    outputs=[
+        gr.Image(label="Best generated CXR"),
+        gr.Image(label="Combined attention heatmap"),
+        gr.Gallery(label="Per-term overlays").style(grid=[3], height=400),
+        gr.JSON(label="Raw VICCA output"),
     ],
-    outputs=gr.JSON(label="VICCA output"),
-    title="VICCA – Visual Interpretation & Comprehension",
-    description=(
-        "Upload a chest X-ray and provide a text report / pathology description. "
-        "The VICCA pipeline will run CXR generation, visual grounding, "
-        "and ROI-level similarity scoring."
-    ),
+    title="VICCA",
 )
 
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)
+
+
+# def vicca_interface(image, text_prompt):
+#     """
+#     image: file from Gradio, we'll use its temp path
+#     text_prompt: report / description
+#     """
+#     # Gradio passes a PIL image or a file path depending on type
+#     # We'll request type='filepath' so this is already a path
+#     image_path = image
+
+#     result = run_vicca(
+#         image_path=image_path,
+#         text_prompt=text_prompt,
+#     )
+
+#     # You could also return the best generated image as an image output
+#     # For now, we expose the dict as JSON
+#     return result
+
+# demo = gr.Interface(
+#     fn=vicca_interface,
+#     inputs=[
+#         gr.Image(type="filepath", label="Chest X-ray"),
+#         gr.Textbox(label="Report / pathology description", lines=3),
+#     ],
+#     outputs=gr.JSON(label="VICCA output"),
+#     title="VICCA – Visual Interpretation & Comprehension",
+#     description=(
+#         "Upload a chest X-ray and provide a text report / pathology description. "
+#         "The VICCA pipeline will run CXR generation, visual grounding, "
+#         "and ROI-level similarity scoring."
+#     ),
+# )
+
 # if __name__ == "__main__":
 #     demo.launch()
-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        debug=False
-    )
+# if __name__ == "__main__":
+#     demo.launch(
+#         server_name="0.0.0.0",
+#         server_port=7860,
+#         debug=False
+#     )
 
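The reworked vicca_interface now receives a PIL image, saves it to uploads/input.png, forwards the threshold and sample settings to run_vicca, and returns four outputs (best generated CXR, combined attention heatmap, per-term overlay gallery, raw dict). A minimal local smoke test of that flow might look like the sketch below; "sample_cxr.png" is a placeholder path, and the heavy VICCA weights must already be available in the environment.

# Sketch: exercise the new interface outside Gradio.
# "sample_cxr.png" is a placeholder; all VICCA weights must already be cached.
from PIL import Image
from app import vicca_interface

img = Image.open("sample_cxr.png").convert("RGB")
best_gen, combined, gallery_items, raw = vicca_interface(
    img,
    "There is a small right pleural effusion.",
    box_threshold=0.2,
    text_threshold=0.2,
    num_samples=4,
)
print("best generated image:", best_gen)
print("combined overlay:", combined)
print("per-term overlays:", [term for _, term in gallery_items])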
inference.py CHANGED
@@ -9,11 +9,13 @@ python inference.py \
 
 """
 import pandas as pd
+import numpy as np
 import time
 import cv2
 import sys
 import argparse
 from ast import literal_eval
+from nltk import tokenize
 
 # sys.path.append('/home/gholipos-admin/Desktop/Thesis/Training_Code/VICCA')
 from pathlib import Path
@@ -101,7 +103,7 @@ from DETR import svc
 from DETR.arguments import get_args_parser as get_detr_args_parser
 from VG import localization
 from ssim import ssim
-
+from CheXbert.src.label import label
 
 def get_args_parser():
     parser = argparse.ArgumentParser('Set the Input', add_help=True)
@@ -120,6 +122,59 @@ def get_args_parser():
                         help="Path to save generated files.")
     return parser
 
+path_list = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
+             'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
+             'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
+             'Support Devices', 'No Finding']
+
+# Cache CheXbert weights once at import time
+CHEXBERT_WEIGHTS = get_weight("CheXbert/checkpoint/chexbert.pth")
+
+# def chexbert_pathology(text):
+#     sentences = list(set(tokenize.sent_tokenize(text)))
+#     path_dict = []
+#     for sentence in sentences:
+#         sentence = sentence.replace('\n',' ')
+#         sentence = sentence.replace('\s+',' ')
+#         chexbert_weight_path = get_weight("CheXbert/checkpoint/chexbert.pth")
+#         # pathology = np.array(label("CheXbert/checkpoint/chexbert.pth", sentence)).T[0]
+#         pathology = np.array(label(chexbert_weight_path, sentence)).T[0]
+#         if pathology[-1]==1 or len(list(set(pathology)))==1 or not any(e==1 for e in pathology):
+#             pass
+#         else:
+#             indice = [i for i, e in enumerate(pathology) if e==1]
+#             for ind in indice:
+#                 path_dict.append(path_list[ind])
+#     return path_dict
+def chexbert_pathology(text: str):
+    """
+    Run CheXbert on the text and return a list of *positive* pathology labels,
+    deduplicated.
+    """
+    # If NLTK punkt ever becomes a problem on Spaces, replace this with a simple split.
+    # sentences = list(set(tokenize.sent_tokenize(text)))
+    sentences = [s.strip() for s in text.split(".") if s.strip()]
+
+    path_terms = set()
+
+    for sentence in sentences:
+        sentence = sentence.replace("\n", " ")
+        sentence = sentence.replace("\s+", " ")
+
+        # Run CheXbert
+        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]
+
+        # Skip if: "No Finding" active, or all labels same, or no positives
+        if pathology[-1] == 1 or len(set(pathology)) == 1 or not any(e == 1 for e in pathology):
+            continue
+
+        # Collect positive indices
+        indices = [i for i, e in enumerate(pathology) if e == 1]
+        for ind in indices:
+            path_terms.add(path_list[ind])
+
+    return sorted(path_terms)
+
 def extract_tensor(value):
     cleaned_value = value.replace('tensor(', '').replace(')', '')
     return literal_eval(cleaned_value)
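The new chexbert_pathology helper splits the report on periods, labels each sentence with CheXbert, skips sentences that are all-negative or flagged "No Finding", and returns the sorted set of positive labels from path_list. A usage sketch follows; the report text is illustrative, and the returned labels depend on the CheXbert checkpoint's predictions.

# Illustrative call only: importing inference pulls in the full VICCA stack and
# requires the cached CheXbert checkpoint, so run this inside the Space environment.
from inference import chexbert_pathology

report = "Mild cardiomegaly. There is a small right pleural effusion. No pneumothorax."
terms = chexbert_pathology(report)
print(terms)  # e.g. ['Cardiomegaly', 'Pleural Effusion'] — actual output depends on the model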
vg_token_attention.py ADDED
@@ -0,0 +1,396 @@
+# vg_token_attention.py
+# -*- coding: utf-8 -*-
+"""
+Token→region cross-attention visualization for GroundingDINO integrated as a helper.
+
+Usage from other modules:
+    from vg_token_attention import run_token_ca_visualization
+
+    paths = run_token_ca_visualization(
+        cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
+        ckpt_path="VG/weights/checkpoint0399_log4.pth",
+        image_path=image_path,
+        prompt=text_prompt,
+        terms=chexbert_terms,          # e.g. ["edema", "effusion"]
+        out_dir="outputs/attn_overlays",
+        device="cuda" or "cpu",
+    )
+"""
+
+import os
+import math
+import re
+import cv2
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+from PIL import Image
+import torchvision.transforms as T
+
+from VG.groundingdino.util.inference import load_model
+from VG.groundingdino.util.misc import NestedTensor
+
+from transformers import AutoTokenizer
+
+DEVICE_DEFAULT = "cuda" if torch.cuda.is_available() else "cpu"
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+# -----------------------------
+# Preprocess: PIL -> (tensor, mask)
+# -----------------------------
+def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
+    to_tensor = T.ToTensor()
+    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+
+    def _resize_longest(pil_img: Image.Image, longest_side=1024):
+        w, h = pil_img.size
+        scale = float(longest_side) / max(w, h)
+        new_w, new_h = int(round(w * scale)), int(round(h * scale))
+        return pil_img.resize((new_w, new_h), Image.BICUBIC)
+
+    def preprocess_image_fn(pil_img: Image.Image):
+        img_resized = _resize_longest(pil_img, longest_side=longest)
+        x = normalize(to_tensor(img_resized))  # [3,H,W]
+        _, H, W = x.shape
+
+        # pad to /32 for backbone
+        H_pad = math.ceil(H / pad_divisor) * pad_divisor
+        W_pad = math.ceil(W / pad_divisor) * pad_divisor
+        pad_h, pad_w = H_pad - H, W_pad - W
+        x = F.pad(x, (0, pad_w, 0, pad_h), value=0.0)  # [3,Hp,Wp]
+
+        # mask: True on padded pixels
+        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
+        if pad_h > 0:
+            mask[H:, :] = True
+        if pad_w > 0:
+            mask[:, W:] = True
+
+        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)
+
+    return preprocess_image_fn
+
+
+# -----------------------------
+# Tokenizer (BiomedVLP-CXR-BERT)
+# -----------------------------
+BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"
+
+_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)
+
+
+def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
+    enc = _tokenizer(
+        prompt,
+        return_tensors="pt",
+        return_offsets_mapping=True,
+        add_special_tokens=True,
+        truncation=True,
+    )
+    tokens = _tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
+    offsets = enc["offset_mapping"][0].tolist()
+    return {
+        "input_ids": enc["input_ids"].to(device),
+        "attention_mask": enc["attention_mask"].to(device),
+        "tokens": tokens,
+        "offsets": offsets,
+    }
+
+
+def find_token_span_by_offsets(prompt: str, offsets, term: str):
+    s = prompt.lower()
+    t = term.lower()
+    m = re.search(r'\b' + re.escape(t) + r'\b', s) or re.search(re.escape(t), s)
+    if not m:
+        return []
+    a, b = m.start(), m.end()
+    idxs = []
+    for i, (u, v) in enumerate(offsets):
+        if (
+            u is None or v is None or
+            u < 0 or v < 0 or
+            (u == 0 and v == 0)
+        ):
+            continue
+        if not (v <= a or u >= b):  # overlap with [a,b)
+            idxs.append(i)
+    return idxs
+
+
+def model_span_indices_for_term(tokens, offsets, attn_T, term: str):
+    # 1) HF indices by offsets
+    raw_hf_idxs = find_token_span_by_offsets(
+        "".join(t if t != "[PAD]" else " " for t in tokens),
+        offsets,
+        term
+    )
+    if not raw_hf_idxs:
+        low = term.lower()
+        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]
+
+    # 2) Map HF non-special → model positions 0..T-1
+    non_special_hf = []
+    for i, (tok_i, (u, v)) in enumerate(zip(tokens, offsets)):
+        if tok_i in ("[CLS]", "[SEP]", "[PAD]"):
+            continue
+        if u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0):
+            continue
+        non_special_hf.append(i)
+
+    non_special_hf = non_special_hf[:attn_T]
+    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
+    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]
+
+    return torch.tensor(model_term_idxs, dtype=torch.long)
+
+
+# -----------------------------
+# Cross-attention recorder
+# -----------------------------
+class CrossAttnRecorder:
+    def __init__(self, decoder_layers, attn_attr_name='ca_text'):
+        self.attn_weights = []  # list of [B, heads, Q, T]
+        self.handles = []
+        self._register(decoder_layers, attn_attr_name)
+
+    def _hook(self, module, input, output):
+        if isinstance(output, tuple) and len(output) >= 2:
+            attn_w = output[1]
+        elif hasattr(module, 'attn_output_weights'):
+            attn_w = module.attn_output_weights
+        else:
+            attn_w = None
+        if attn_w is not None:
+            self.attn_weights.append(attn_w.detach().to('cpu', dtype=torch.float32))
+
+    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
+        orig_forward = mha_module.forward
+
+        def wrapped_forward(*args, **kwargs):
+            kwargs['need_weights'] = True
+            kwargs['average_attn_weights'] = False
+            return orig_forward(*args, **kwargs)
+
+        return orig_forward, wrapped_forward
+
+    def _register(self, decoder_layers, attn_attr_name):
+        for layer in decoder_layers:
+            attn_module = getattr(layer, attn_attr_name, None)
+            if attn_module is None:
+                continue
+            if isinstance(attn_module, nn.MultiheadAttention):
+                orig_fwd, wrapped = self._wrap_forward(attn_module)
+                attn_module.forward = wrapped
+                handle = attn_module.register_forward_hook(self._hook)
+                self.handles.append((attn_module, handle, orig_fwd))
+            else:
+                handle = attn_module.register_forward_hook(self._hook)
+                self.handles.append((attn_module, handle, None))
+
+    def close(self):
+        for attn_module, handle, orig_fwd in self.handles:
+            handle.remove()
+            if (orig_fwd is not None) and isinstance(attn_module, nn.MultiheadAttention):
+                attn_module.forward = orig_fwd
+
+
+# -----------------------------
+# Heatmap helpers
+# -----------------------------
+def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
+    H, W = hw
+    heat = np.zeros((H, W), dtype=np.float32)
+
+    w = weights.detach().cpu().numpy()
+    if score_scale is not None:
+        s = score_scale.detach().cpu().numpy()
+        w = w * s
+
+    for i, box in enumerate(boxes_xyxy):
+        x1, y1, x2, y2 = map(int, box.tolist())
+        x1 = max(0, min(W - 1, x1)); x2 = max(0, min(W - 1, x2))
+        y1 = max(0, min(H - 1, y1)); y2 = max(0, min(H - 1, y2))
+        if x2 <= x1 or y2 <= y1:
+            continue
+        heat[y1:y2, x1:x2] += float(w[i])
+
+    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
+        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)
+
+    mx = heat.max()
+    if mx > 1e-6:
+        heat /= mx
+    return heat
+
+
+def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
+    img = np.array(img_pil.convert("RGB"))
+    H, W = img.shape[:2]
+    h = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
+    h_color = cv2.applyColorMap(h, cmap)[:, :, ::-1]
+    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
+    return Image.fromarray(blended)
+
+
+def load_image_keep_longest(path, longest=1024):
+    img = Image.open(path).convert("RGB")
+    w, h = img.size
+    s = float(longest) / max(w, h)
+    new_w, new_h = int(round(w * s)), int(round(h * s))
+    return img.resize((new_w, new_h), Image.BICUBIC)
+
+
+# -----------------------------
+# Main helper: one call from API
+# -----------------------------
+@torch.no_grad()
+def run_token_ca_visualization(
+    cfg_path: str,
+    ckpt_path: str,
+    image_path: str,
+    prompt: str,
+    terms,
+    out_dir: str,
+    device: str = DEVICE_DEFAULT,
+    score_thresh: float = 0.25,
+    topk: int = 100,
+    term_agg: str = "mean",  # "mean" | "max" | "sum"
+    save_per_term: bool = True,
+):
+    """
+    Returns:
+        {
+          "combined": <path_to_combined_overlay>,
+          "per_term": { term: path_to_overlay, ... }
+        }
+    """
+    if isinstance(terms, str):
+        terms = [terms]
+    terms = [t.strip() for t in terms if t and t.strip()]
+    if not terms:
+        raise ValueError("No terms provided for attention visualization.")
+
+    device = device or DEVICE_DEFAULT
+    model = load_model(cfg_path, ckpt_path).to(device).eval()
+    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)
+
+    img_pil = load_image_keep_longest(image_path, longest=1024)
+
+    os.makedirs(out_dir, exist_ok=True)
+    base_name = os.path.splitext(os.path.basename(image_path))[0]
+    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")
+
+    # ---- hook cross-attn
+    decoder_layers = model.transformer.decoder.layers
+    recorder = CrossAttnRecorder(decoder_layers, attn_attr_name="ca_text")
+
+    # preprocess → NestedTensor
+    img_tensor, mask = preprocess_image_fn(img_pil)
+    samples = NestedTensor(img_tensor, mask)
+
+    outputs = model(samples, captions=[prompt])
+
+    # decode boxes
+    pred_logits = outputs["pred_logits"]
+    pred_boxes = outputs["pred_boxes"]
+    logits = pred_logits[0].sigmoid()
+    scores, _ = logits.max(dim=1)
+
+    keep = torch.nonzero(scores > score_thresh).squeeze(1)
+    if keep.numel() == 0:
+        keep = torch.argsort(scores, descending=True)[:min(topk, scores.numel())]
+    else:
+        keep = keep[:topk]
+
+    W, H = img_pil.size
+    boxes_cxcywh = pred_boxes[0][keep]
+    cx, cy, w, h = boxes_cxcywh.unbind(-1)
+    x1 = (cx - 0.5 * w) * W
+    y1 = (cy - 0.5 * h) * H
+    x2 = (cx + 0.5 * w) * W
+    y2 = (cy + 0.5 * h) * H
+    boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=-1)
+    kept_scores = scores[keep]
+
+    keep_cpu = keep.cpu()
+
+    if len(recorder.attn_weights) == 0:
+        recorder.close()
+        raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")
+    attn_qt_layers = []
+    for w_att in recorder.attn_weights:
+        w_att = w_att.squeeze(0).mean(0)  # [Q,T]
+        attn_qt_layers.append(w_att)
+    attn_qt = torch.stack(attn_qt_layers, 0).mean(0)  # [Q,T]
+    recorder.close()
+
+    # tokenize prompt
+    tok = tokenize_with_offsets(prompt, device="cpu")
+    tokens, offsets = tok["tokens"], tok["offsets"]
+    T_text = attn_qt.shape[1]
+
+    per_term_attn_kept = {}
+    per_term_attn_full = {}
+
+    for t in terms:
+        model_idxs = model_span_indices_for_term(tokens, offsets, T_text, t)
+        if model_idxs.numel() == 0:
+            continue
+        attn_per_query = attn_qt[:, model_idxs].mean(1)  # [Q]
+        attn_kept = attn_per_query[keep_cpu]
+        attn_kept = (attn_kept - attn_kept.min()) / (attn_kept.max() - attn_kept.min() + 1e-6)
+        per_term_attn_kept[t] = attn_kept
+        per_term_attn_full[t] = attn_per_query
+
+    if not per_term_attn_kept:
+        raise ValueError(f"None of the terms were found in the first T tokens: {terms}")
+
+    # aggregate terms
+    agg = None
+    for t, v in per_term_attn_full.items():
+        agg = v if agg is None else (
+            agg + v if term_agg == "sum"
+            else torch.maximum(agg, v) if term_agg == "max"
+            else (agg + v)
+        )
+    if term_agg == "mean":
+        agg = agg / float(len(per_term_attn_full))
+
+    agg_kept = agg[keep_cpu]
+    agg_kept = (agg_kept - agg_kept.min()) / (agg_kept.max() - agg_kept.min() + 1e-6)
+
+    heat = boxes_to_heatmap(
+        boxes_xyxy=boxes_xyxy,
+        weights=agg_kept,
+        hw=(H, W),
+        score_scale=kept_scores,
+        blur_ksize=61,
+        blur_sigma=0,
+    )
+    overlay = overlay_heatmap(img_pil, heat, alpha=0.45)
+    overlay.save(combined_path)
+
+    per_term_paths = {}
+    if save_per_term and len(per_term_attn_kept) > 1:
+        for t, v in per_term_attn_kept.items():
+            heat_t = boxes_to_heatmap(
+                boxes_xyxy=boxes_xyxy,
+                weights=v,
+                hw=(H, W),
+                score_scale=kept_scores,
+                blur_ksize=61,
+                blur_sigma=0,
+            )
+            ov_t = overlay_heatmap(img_pil, heat_t, alpha=0.45)
+            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32]
+            p = os.path.join(out_dir, f"{base_name}__{term_tag}.png")
+            ov_t.save(p)
+            per_term_paths[t] = p
+
+    return {
+        "combined": combined_path,
+        "per_term": per_term_paths,
+    }
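The heatmap stage is independent of the detector: boxes_to_heatmap paints each kept box with its (optionally score-scaled) attention weight, Gaussian-blurs the map, and normalizes it to [0, 1], and overlay_heatmap blends it over the image with a JET colormap. A self-contained sketch with synthetic boxes is shown below; note that importing the module still loads the local tokenizer and GroundingDINO utilities, so those dependencies and weights are assumed to be in place.

# Synthetic demo of the heatmap helpers; no GroundingDINO forward pass is run.
import torch
from PIL import Image
from vg_token_attention import boxes_to_heatmap, overlay_heatmap

img = Image.new("RGB", (512, 512), color=128)       # stand-in for a resized CXR
boxes = torch.tensor([[60.0, 80.0, 220.0, 300.0],    # two fake detections, xyxy in pixels
                      [280.0, 150.0, 460.0, 340.0]])
weights = torch.tensor([0.9, 0.4])                   # per-box token attention
scores = torch.tensor([0.8, 0.6])                    # detector confidences

heat = boxes_to_heatmap(boxes, weights, hw=(512, 512), score_scale=scores, blur_ksize=61)
overlay = overlay_heatmap(img, heat, alpha=0.45)
overlay.save("demo_overlay.png")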
vicca_api.py CHANGED
@@ -3,8 +3,11 @@ import os
 import time
 import cv2
 import pandas as pd
+import torch
 
 from weights_utils import ensure_all_vicca_weights, get_weight
+from vg_token_attention import run_token_ca_visualization
+
 
 # Make sure all heavy weights are present once per container
 ensure_all_vicca_weights()
@@ -14,6 +17,7 @@ from inference import (
     cal_shift,
     get_local_bbox,
     extract_tensor,
+    chexbert_pathology,
 )
 
 def run_vicca(
@@ -23,6 +27,7 @@ def run_vicca(
     text_threshold: float = 0.2,
     num_samples: int = 4,
     output_path: str = "CXRGen/test/samples/output/",
+    attn_terms=None,
 ):
     """
     Top-level VICCA API used by app.py / Gradio.
@@ -87,6 +92,28 @@ def run_vicca(
         score = ssim(roi_org, roi_gen)
         ssim_scores.append(score)
 
+    # Optional: attention visualization for terms (e.g. from CheXbert)
+    attn_paths = None
+    attn_terms = chexbert_pathology(text_prompt)
+    if attn_terms:
+        cfg_path = "VG/config/GroundingDINO_SwinT_OGC_2.py"
+        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
+        attn_out_dir = os.path.join(output_path, "attn_overlays")
+
+        attn_paths = run_token_ca_visualization(
+            cfg_path=cfg_path,
+            ckpt_path=vg_ckpt_path,
+            image_path=image_path,  # or max_sim_gen_path if you prefer generated CXR
+            prompt=text_prompt,
+            terms=attn_terms,
+            out_dir=attn_out_dir,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            score_thresh=0.25,
+            topk=100,
+            term_agg="mean",
+            save_per_term=True,
+        )
+
     return {
         "boxes": boxes,
         "logits": logits,
@@ -95,4 +122,5 @@ def run_vicca(
         "shift_x": sx,
         "shift_y": sy,
         "best_generated_image_path": max_sim_gen_path,
+        "attention_overlays": attn_paths,
     }
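As committed, run_vicca derives the attention terms from chexbert_pathology(text_prompt) (the new attn_terms parameter is overwritten before use) and adds an "attention_overlays" entry to its return dict, which app.py unpacks into the combined heatmap and the per-term gallery. A sketch of consuming the extended output; the prompt and file names are illustrative and only follow the naming used in run_token_ca_visualization.

# Sketch: consuming the extended run_vicca output (illustrative values only).
from vicca_api import run_vicca

result = run_vicca(
    image_path="uploads/input.png",
    text_prompt="Cardiomegaly and a small right pleural effusion.",
)
attn = result.get("attention_overlays") or {}  # None when CheXbert finds no positive terms
# Expected shape, e.g.:
# {
#     "combined": "CXRGen/test/samples/output/attn_overlays/input__attn_combined.png",
#     "per_term": {
#         "Cardiomegaly": ".../attn_overlays/input__cardiomegaly.png",
#         "Pleural Effusion": ".../attn_overlays/input__pleural_effusion.png",
#     },
# }
print(attn.get("combined"))
print(list((attn.get("per_term") or {}).keys()))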