|
|
|
|
|
|
|
|
""" |
|
|
Helper for visualizing GroundingDINO token→region cross-attention as heatmap overlays.
|
|
|
|
|
Usage from other modules: |
|
|
from vg_token_attention import run_token_ca_visualization |
|
|
|
|
|
paths = run_token_ca_visualization( |
|
|
cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py", |
|
|
ckpt_path="VG/weights/checkpoint0399_log4.pth", |
|
|
image_path=image_path, |
|
|
prompt=text_prompt, |
|
|
terms=chexbert_terms, # e.g. ["edema", "effusion"] |
|
|
out_dir="outputs/attn_overlays", |
|
|
        device="cuda",  # or "cpu"
|
|
) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import math |
|
|
import re |
|
|
import cv2 |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
|
|
|
from VG.groundingdino.util.inference import load_model |
|
|
from VG.groundingdino.util.misc import NestedTensor |
|
|
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Default compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE_DEFAULT = "cuda"
else:
    DEVICE_DEFAULT = "cpu"

# Per-channel (R, G, B) statistics handed to ``T.Normalize`` during preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
    """Build a callable that turns a PIL image into a padded, normalized batch.

    The returned ``preprocess_image_fn(pil_img)``:
      1. converts the image to RGB (if it is not already),
      2. resizes it so its longer side equals ``longest`` (bicubic, aspect
         ratio kept; smaller images are upscaled),
      3. converts it to a tensor normalized with ``IMAGENET_MEAN``/``IMAGENET_STD``,
      4. zero-pads the bottom/right edges so height and width are multiples
         of ``pad_divisor``,
      5. builds a boolean mask that is ``True`` on padded pixels.

    Args:
        device: device the returned tensors are moved to.
        longest: target length of the longer image side, in pixels.
        pad_divisor: spatial dimensions are padded up to a multiple of this.

    Returns:
        A function returning ``(x, mask)`` with ``x`` of shape
        ``(1, 3, H_pad, W_pad)`` and ``mask`` of shape ``(1, H_pad, W_pad)``.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def _resize_longest(pil_img: Image.Image, longest_side=1024):
        # Scale both sides by one factor so the longer side becomes longest_side.
        w, h = pil_img.size
        scale = float(longest_side) / max(w, h)
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        return pil_img.resize((new_w, new_h), Image.BICUBIC)

    def preprocess_image_fn(pil_img: Image.Image):
        # Normalize uses 3-channel statistics; modes such as "L" (grayscale)
        # or "RGBA" do not yield 3 channels and would make it fail. Converting
        # before resizing also avoids PIL's nearest-neighbour fallback for
        # palette images.
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")

        img_resized = _resize_longest(pil_img, longest_side=longest)
        x = normalize(to_tensor(img_resized))
        _, H, W = x.shape

        # Pad bottom/right only. The pad value 0.0 is applied after
        # normalization, so padded pixels equal the normalization mean.
        H_pad = math.ceil(H / pad_divisor) * pad_divisor
        W_pad = math.ceil(W / pad_divisor) * pad_divisor
        pad_h, pad_w = H_pad - H, W_pad - W
        x = F.pad(x, (0, pad_w, 0, pad_h), value=0.0)

        # True marks padding (rows >= H or columns >= W); False marks image pixels.
        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
        if pad_h > 0:
            mask[H:, :] = True
        if pad_w > 0:
            mask[:, W:] = True

        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)

    return preprocess_image_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local directory with the BiomedVLP-CXR-BERT tokenizer files. The path is
# relative, so it is resolved against the current working directory.
BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"



# Module-level tokenizer shared by ``tokenize_with_offsets``. It is loaded once,
# at import time, so any error raised by ``from_pretrained`` (e.g. a missing
# directory) surfaces when this module is imported.
# NOTE(review): ``return_offsets_mapping`` (used below) is only supported by
# "fast" tokenizers — confirm this directory provides one.
_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)
|
|
|
|
|
|
|
|
def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
    """Tokenize ``prompt`` with the shared tokenizer, keeping character offsets.

    Special tokens are added and over-long prompts are truncated.

    Returns:
        dict with
          - ``"input_ids"`` / ``"attention_mask"``: batch-of-one tensors on ``device``;
          - ``"tokens"``: the token strings of the single sequence;
          - ``"offsets"``: per-token ``[start, end]`` character offsets into
            ``prompt``, as plain Python lists.
    """
    encoding = _tokenizer(
        prompt,
        add_special_tokens=True,
        truncation=True,
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    ids = encoding["input_ids"]
    token_strings = _tokenizer.convert_ids_to_tokens(ids[0])
    char_offsets = encoding["offset_mapping"][0].tolist()

    result = {}
    result["input_ids"] = ids.to(device)
    result["attention_mask"] = encoding["attention_mask"].to(device)
    result["tokens"] = token_strings
    result["offsets"] = char_offsets
    return result
|
|
|
|
|
|
|
|
# Token strings whose attention columns never correspond to prompt text.
_SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")


def _is_special_offset(u, v):
    """True for offset pairs used for special/padding tokens (None, negative, or (0, 0))."""
    return u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0)


def find_token_span_by_offsets(prompt: str, offsets, term: str):
    """Return indices of tokens whose character span overlaps ``term`` in ``prompt``.

    The first case-insensitive occurrence of ``term`` is located, as a whole
    word if possible and otherwise as a plain substring. Every token whose
    ``(start, end)`` offset overlaps that occurrence is returned. Offsets that
    mark special tokens are skipped.

    The search runs on ``prompt`` itself with ``re.IGNORECASE`` instead of on
    ``prompt.lower()``. Lower-casing can change the string length for some
    Unicode characters (``"İ".lower()`` is two code points), which would shift
    match positions away from the tokenizer offsets.

    Args:
        prompt: text the offsets index into.
        offsets: per-token ``(start, end)`` character offsets into ``prompt``.
        term: phrase to locate.

    Returns:
        list[int]: overlapping token indices in ascending order; ``[]`` if the
        term is empty or not found.
    """
    if not term:
        return []
    pattern = re.escape(term)
    m = (re.search(r'\b' + pattern + r'\b', prompt, flags=re.IGNORECASE)
         or re.search(pattern, prompt, flags=re.IGNORECASE))
    if not m:
        return []
    a, b = m.start(), m.end()
    # Half-open intervals [u, v) and [a, b) overlap iff u < b and v > a.
    return [
        i for i, (u, v) in enumerate(offsets)
        if not _is_special_offset(u, v) and u < b and v > a
    ]


def _text_aligned_to_offsets(tokens, offsets):
    """Rebuild text whose character positions agree with ``offsets``.

    Each non-special token's text (WordPiece ``##`` prefix removed) is written
    into its ``[start, end)`` slot, and uncovered positions become spaces. A
    token whose text length differs from its span (e.g. ``[UNK]``) is
    truncated or space-padded to fit, so later positions stay aligned.
    """
    pieces = [
        (tok, u, v) for tok, (u, v) in zip(tokens, offsets)
        if tok not in _SPECIAL_TOKENS and not _is_special_offset(u, v) and v > u
    ]
    if not pieces:
        return ""
    chars = [" "] * max(v for _, _, v in pieces)
    for tok, u, v in pieces:
        text = tok[2:] if tok.startswith("##") else tok
        chars[u:v] = text.ljust(v - u)[: v - u]
    return "".join(chars)


def model_span_indices_for_term(tokens, offsets, attn_T, term: str, prompt=None):
    """Map ``term`` to column indices on the model's text-attention axis.

    1. Find the HF token indices covering ``term`` with
       ``find_token_span_by_offsets``. The search text is ``prompt`` when
       given. Otherwise it is rebuilt from ``tokens`` so its character
       positions match ``offsets``; the joined token strings cannot be used,
       because their positions do not line up with the offsets. If nothing is
       found, fall back to tokens whose text contains ``term``
       (case-insensitive).
    2. Translate HF indices to model indices by enumerating the non-special
       tokens in order, keeping only the first ``attn_T`` of them.

    Args:
        tokens: token strings from the tokenizer.
        offsets: matching ``(start, end)`` character offsets.
        attn_T: length of the attention's text axis.
        term: phrase to locate.
        prompt: optional original prompt the offsets refer to.

    Returns:
        1-D ``torch.long`` tensor of model indices (possibly empty).
    """
    search_text = prompt if prompt is not None else _text_aligned_to_offsets(tokens, offsets)
    raw_hf_idxs = find_token_span_by_offsets(search_text, offsets, term)
    if not raw_hf_idxs:
        low = term.lower()
        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]

    # NOTE(review): assumes the attention's text axis holds only the
    # non-special tokens, in order. If the model's text axis also contains
    # [CLS]/[SEP], these indices are shifted by one — confirm against the
    # text encoder.
    non_special_hf = [
        i for i, (tok, (u, v)) in enumerate(zip(tokens, offsets))
        if tok not in _SPECIAL_TOKENS and not _is_special_offset(u, v)
    ][:attn_T]
    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]
    return torch.tensor(model_term_idxs, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossAttnRecorder:
    """Capture attention weights from a named sub-module of each decoder layer.

    For every layer in ``decoder_layers`` that has an attribute
    ``attn_attr_name`` (default ``'ca_text'``), a forward hook appends the
    module's attention weights to ``self.attn_weights``, detached and as
    float32 CPU tensors, in call order. ``nn.MultiheadAttention`` modules are
    also patched so every call runs with ``need_weights=True`` and
    ``average_attn_weights=False``; they therefore return per-head weights.

    Call :meth:`close`, or use the recorder as a context manager, to remove the
    hooks and undo the patch. ``close`` is safe to call more than once.
    """

    def __init__(self, decoder_layers, attn_attr_name='ca_text'):
        # One tensor per hooked forward call, in call order.
        self.attn_weights = []
        # (module, hook handle, original ``forward`` or None) per hooked module.
        self.handles = []
        self._register(decoder_layers, attn_attr_name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _hook(self, module, input, output):
        """Forward hook: record ``output[1]`` or ``module.attn_output_weights``."""
        if isinstance(output, tuple) and len(output) >= 2:
            attn_w = output[1]
        else:
            attn_w = getattr(module, 'attn_output_weights', None)
        # Only tensors can be recorded; MHA yields None when need_weights=False.
        if isinstance(attn_w, torch.Tensor):
            self.attn_weights.append(attn_w.detach().to('cpu', dtype=torch.float32))

    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
        """Return ``(original_forward, wrapped_forward)`` for ``mha_module``.

        The wrapper forces ``need_weights=True`` and
        ``average_attn_weights=False``. It does so whether the caller passed
        these by keyword or positionally, as the 5th and 7th arguments of
        ``MultiheadAttention.forward``.
        """
        orig_forward = mha_module.forward

        def wrapped_forward(*args, **kwargs):
            args = list(args)
            if len(args) > 4:
                args[4] = True
            else:
                kwargs['need_weights'] = True
            if len(args) > 6:
                args[6] = False
            else:
                kwargs['average_attn_weights'] = False
            return orig_forward(*args, **kwargs)

        return orig_forward, wrapped_forward

    def _register(self, decoder_layers, attn_attr_name):
        """Hook ``attn_attr_name`` on each layer that has it; patch MHA modules first."""
        for layer in decoder_layers:
            attn_module = getattr(layer, attn_attr_name, None)
            if attn_module is None:
                continue
            orig_fwd = None
            if isinstance(attn_module, nn.MultiheadAttention):
                orig_fwd, wrapped = self._wrap_forward(attn_module)
                attn_module.forward = wrapped
            handle = attn_module.register_forward_hook(self._hook)
            self.handles.append((attn_module, handle, orig_fwd))

    @staticmethod
    def _restore_forward(module, orig_fwd):
        """Undo the ``forward`` patch on ``module``.

        If ``orig_fwd`` is just the class's own ``forward`` bound to
        ``module``, the instance attribute is deleted so normal method lookup
        resumes. Otherwise the instance already had its own ``forward``, and
        that value is put back.
        """
        is_class_forward = (
            getattr(orig_fwd, '__self__', None) is module
            and getattr(orig_fwd, '__func__', None) is getattr(type(module), 'forward', None)
        )
        if is_class_forward:
            vars(module).pop('forward', None)
        else:
            module.forward = orig_fwd

    def close(self):
        """Remove all hooks and restore patched ``forward`` methods (idempotent).

        Recorded ``attn_weights`` are kept.
        """
        for attn_module, handle, orig_fwd in self.handles:
            handle.remove()
            if (orig_fwd is not None) and isinstance(attn_module, nn.MultiheadAttention):
                self._restore_forward(attn_module, orig_fwd)
        self.handles.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to the CPU."""
    if hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
    """Accumulate weighted boxes into a heatmap normalized to a peak of 1.

    Every pixel covered by box ``i`` receives ``weights[i]``, multiplied by
    ``score_scale[i]`` when given. Overlapping boxes add up. The map is then
    optionally Gaussian-blurred and divided by its maximum.

    Args:
        boxes_xyxy: (N, 4) pixel boxes ``x1, y1, x2, y2`` (tensor or
            array-like). Coordinates are truncated with ``int`` and clipped to
            the map. ``x2``/``y2`` are exclusive ends, so a box reaching ``W``
            or ``H`` covers the last column or row.
        weights: length-N per-box weights (tensor or array-like).
        hw: ``(H, W)`` of the output map.
        score_scale: optional length-N multipliers applied to ``weights``.
        blur_ksize: an odd kernel size >= 3 enables ``cv2.GaussianBlur``; any
            other value, including None, disables blurring.
        blur_sigma: sigma for ``cv2.GaussianBlur``. 0 lets OpenCV derive it
            from the kernel size.

    Returns:
        float32 array of shape (H, W). It is divided by its max unless the
        max is <= 1e-6.
    """
    H, W = hw
    heat = np.zeros((H, W), dtype=np.float32)

    w = _to_numpy(weights).reshape(-1)
    if score_scale is not None:
        w = w * _to_numpy(score_scale).reshape(-1)

    boxes = _to_numpy(boxes_xyxy).reshape(-1, 4)
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box.tolist())
        # Clip to [0, W] / [0, H]: slice ends are exclusive, so clipping to
        # W-1 / H-1 would silently drop the last column / row.
        x1 = min(max(x1, 0), W)
        x2 = min(max(x2, 0), W)
        y1 = min(max(y1, 0), H)
        y2 = min(max(y2, 0), H)
        if x2 <= x1 or y2 <= y1:
            continue
        heat[y1:y2, x1:x2] += float(w[i])

    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)

    mx = heat.max()
    if mx > 1e-6:
        heat /= mx
    return heat
|
|
|
|
|
|
|
|
def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
    """Blend a heatmap, rendered with an OpenCV colormap, onto an image.

    Args:
        img_pil: background image (converted to RGB).
        heatmap: 2-D array-like. Values are clipped to [0, 1]; NaN and -inf
            count as 0, +inf as 1. If its shape differs from the image, it is
            bilinearly resized to the image size first.
        alpha: weight of the colored heatmap; the image gets ``1 - alpha``.
        cmap: OpenCV colormap id.

    Returns:
        A new RGB PIL image.
    """
    img = np.array(img_pil.convert("RGB"))
    H, W = img.shape[:2]

    heat = np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0, posinf=1.0, neginf=0.0)
    if heat.shape[:2] != (H, W):
        # addWeighted requires equal sizes; cv2.resize takes (width, height).
        heat = cv2.resize(heat, (W, H), interpolation=cv2.INTER_LINEAR)

    h = (np.clip(heat, 0, 1) * 255).astype(np.uint8)
    # applyColorMap produces BGR; reverse the channel axis to match the RGB image.
    h_color = np.ascontiguousarray(cv2.applyColorMap(h, cmap)[:, :, ::-1])
    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
    return Image.fromarray(blended)
|
|
|
|
|
|
|
|
def load_image_keep_longest(path, longest=1024):
    """Open ``path`` as RGB and rescale it so its longer side equals ``longest``.

    Both sides are scaled by the same factor and rounded to the nearest pixel,
    so the aspect ratio is preserved; smaller images are upscaled. Resampling
    is bicubic.
    """
    with Image.open(path) as source:
        rgb = source.convert("RGB")
    width, height = rgb.size
    factor = float(longest) / max(width, height)
    target_size = (int(round(width * factor)), int(round(height * factor)))
    return rgb.resize(target_size, Image.BICUBIC)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _select_query_indices(scores, score_thresh, topk):
    """Return indices of the queries to visualize, highest score first.

    Queries scoring above ``score_thresh`` are kept, at most ``topk`` of them.
    If none pass, the ``topk`` highest-scoring queries are used instead so an
    overlay is still produced.
    """
    above = torch.nonzero(scores > score_thresh).squeeze(1)
    if above.numel() == 0:
        return torch.argsort(scores, descending=True)[:min(topk, scores.numel())]
    # Rank survivors by score so truncating to ``topk`` drops the weakest
    # boxes, not whichever happen to have the highest query index.
    return above[torch.argsort(scores[above], descending=True)][:topk]


def _minmax_normalize(x, eps=1e-6):
    """Rescale ``x`` so its min maps to 0 and its max to ~1.

    A constant input, e.g. a single kept box, has no spread; the plain formula
    would return all zeros and hence a blank overlay, so it maps to all ones
    instead. Empty input is returned unchanged.
    """
    if x.numel() == 0:
        return x
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span <= 0:
        return torch.ones_like(x)
    return (x - lo) / (span + eps)


def _aggregate_term_attention(per_term, term_agg):
    """Combine per-term attention vectors with "max", "mean" or "sum".

    Any other ``term_agg`` value behaves like "sum".
    """
    stacked = torch.stack(list(per_term.values()), 0)
    if term_agg == "max":
        return stacked.max(0).values
    total = stacked.sum(0)
    if term_agg == "mean":
        total = total / float(len(per_term))
    return total


def _save_box_overlay(img_pil, boxes_xyxy, weights, kept_scores, path):
    """Paint weight- and score-scaled boxes as a blurred heatmap over ``img_pil`` and save to ``path``."""
    W, H = img_pil.size
    heat = boxes_to_heatmap(
        boxes_xyxy=boxes_xyxy,
        weights=weights,
        hw=(H, W),
        score_scale=kept_scores,
        blur_ksize=61,
        blur_sigma=0,
    )
    overlay_heatmap(img_pil, heat, alpha=0.45).save(path)


@torch.no_grad()
def run_token_ca_visualization(
    cfg_path: str,
    ckpt_path: str,
    image_path: str,
    prompt: str,
    terms,
    out_dir: str,
    device: str = DEVICE_DEFAULT,
    score_thresh: float = 0.25,
    topk: int = 100,
    term_agg: str = "mean",
    save_per_term: bool = True,
):
    """Render token→region cross-attention overlays for ``terms`` in ``prompt``.

    Pipeline:
      1. Keep only the terms that occur, case-insensitively, as substrings of
         ``prompt``.
      2. Load the model and run it once on the image, resized so its longest
         side is 1024. A ``CrossAttnRecorder`` captures the weights of every
         decoder layer's ``ca_text`` module during this pass.
      3. Select the queries to draw: those scoring above ``score_thresh``,
         best ``topk`` by score, falling back to the overall top ``topk``.
      4. For each term, take every query's mean attention over the term's
         tokens, min-max normalize it over the kept queries, and paint each
         kept box with that weight times its score.
      5. Save a combined overlay (terms merged with ``term_agg``) and, when
         more than one term matched and ``save_per_term`` is set, one overlay
         per term.

    Args:
        cfg_path, ckpt_path: model config and checkpoint passed to ``load_model``.
        image_path: input image. Its base name prefixes the output files.
        prompt: caption fed to the model.
        terms: one term or a list of terms to visualize.
        out_dir: output directory, created if missing.
        device: torch device; falsy values fall back to ``DEVICE_DEFAULT``.
        score_thresh: minimum query score for a box to be drawn.
        topk: maximum number of boxes drawn.
        term_agg: "mean", "sum" or "max"; other values behave like "sum".
        save_per_term: also write one overlay per term (only if >1 term matched).

    Returns:
        ``{"combined": <path>, "per_term": {term: <path>, ...}}``.
        ``per_term`` is empty when fewer than two terms matched or
        ``save_per_term`` is False. Returns ``{}`` when no term is found in the
        prompt or in the attended tokens.

    Raises:
        RuntimeError: if no attention weights were captured from ``ca_text``.
    """
    if isinstance(terms, str):
        terms = [terms]

    # Cheap pre-filter: only terms occurring in the prompt can have a token span.
    prompt_lower = prompt.lower()
    terms = [t for t in terms if t.lower() in prompt_lower]
    if not terms:
        print(f"[TokenCA] No configured terms found in prompt: {prompt!r}")
        return {}

    device = device or DEVICE_DEFAULT
    model = load_model(cfg_path, ckpt_path).to(device).eval()
    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)
    img_pil = load_image_keep_longest(image_path, longest=1024)

    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")

    # try/finally guarantees the hooks are removed even if the forward pass fails.
    recorder = CrossAttnRecorder(model.transformer.decoder.layers, attn_attr_name="ca_text")
    try:
        img_tensor, mask = preprocess_image_fn(img_pil)
        outputs = model(NestedTensor(img_tensor, mask), captions=[prompt])
        if len(recorder.attn_weights) == 0:
            raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")
        # Per recorded call: drop the leading batch dim and average over the
        # next dim, then average over calls -> (queries, text_tokens).
        # NOTE(review): for nn.MultiheadAttention these weights are
        # (batch, heads, queries, keys); assumes batch size 1 — confirm.
        attn_qt = torch.stack(
            [w_att.squeeze(0).mean(0) for w_att in recorder.attn_weights], 0
        ).mean(0)
    finally:
        recorder.close()

    # Query confidence = max sigmoid logit over the last dim of pred_logits.
    scores = outputs["pred_logits"][0].sigmoid().max(dim=1).values
    keep = _select_query_indices(scores, score_thresh, topk)
    keep_cpu = keep.cpu()
    kept_scores = scores[keep]

    # Boxes are (cx, cy, w, h) relative to the image; convert to pixel corners.
    W, H = img_pil.size
    cx, cy, bw, bh = outputs["pred_boxes"][0][keep].unbind(-1)
    boxes_xyxy = torch.stack(
        [(cx - 0.5 * bw) * W, (cy - 0.5 * bh) * H, (cx + 0.5 * bw) * W, (cy + 0.5 * bh) * H],
        dim=-1,
    ).cpu()

    tok = tokenize_with_offsets(prompt, device="cpu")
    tokens, offsets = tok["tokens"], tok["offsets"]
    T_text = attn_qt.shape[1]

    # Mean attention each query pays to the tokens of each term.
    per_term_attn_full = {}
    for t in terms:
        model_idxs = model_span_indices_for_term(tokens, offsets, T_text, t)
        if model_idxs.numel() > 0:
            per_term_attn_full[t] = attn_qt[:, model_idxs].mean(1)

    if not per_term_attn_full:
        print(f"[TokenCA] None of the terms were found in the first T tokens: {terms}")
        return {}

    agg = _aggregate_term_attention(per_term_attn_full, term_agg)
    _save_box_overlay(img_pil, boxes_xyxy, _minmax_normalize(agg[keep_cpu]), kept_scores, combined_path)

    per_term_paths = {}
    # With a single term the per-term overlay would duplicate the combined one.
    if save_per_term and len(per_term_attn_full) > 1:
        for t, v in per_term_attn_full.items():
            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32]
            p = os.path.join(out_dir, f"{base_name}__{term_tag}.png")
            _save_box_overlay(img_pil, boxes_xyxy, _minmax_normalize(v[keep_cpu]), kept_scores, p)
            per_term_paths[t] = p

    return {
        "combined": combined_path,
        "per_term": per_term_paths,
    }
|
|
|