# Origin: "vicca" repository, file vg_token_attention.py
# (uploaded by sayehghp; commit 3600cfd, message "Visualization").
# vg_token_attention.py
# -*- coding: utf-8 -*-
"""
Token→region cross-attention visualization for GroundingDINO integrated as a helper.
Usage from other modules:
from vg_token_attention import run_token_ca_visualization
paths = run_token_ca_visualization(
cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
ckpt_path="VG/weights/checkpoint0399_log4.pth",
image_path=image_path,
prompt=text_prompt,
terms=chexbert_terms, # e.g. ["edema", "effusion"]
out_dir="outputs/attn_overlays",
device="cuda" or "cpu",
)
"""
import os
import math
import re
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from PIL import Image
import torchvision.transforms as T
from VG.groundingdino.util.inference import load_model
from VG.groundingdino.util.misc import NestedTensor
from transformers import AutoTokenizer
# Default device: prefer the GPU whenever CUDA is available.
if torch.cuda.is_available():
    DEVICE_DEFAULT = "cuda"
else:
    DEVICE_DEFAULT = "cpu"

# Per-channel RGB normalization statistics used by the image preprocessing below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# -----------------------------
# Preprocess: PIL -> (tensor, mask)
# -----------------------------
def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
    """Build a callable turning a PIL image into a ``(tensor, mask)`` model input.

    The returned ``preprocess_image_fn(pil_img)``:
      1. converts the image to RGB if needed,
      2. resizes it so its longest side equals ``longest`` (bicubic, aspect kept),
      3. converts it to a tensor and normalizes it with IMAGENET_MEAN / IMAGENET_STD,
      4. zero-pads bottom/right so H and W are multiples of ``pad_divisor``,
      5. builds a boolean mask that is True on padded pixels.

    Args:
        device: Device the returned tensors are moved to.
        longest: Target length, in pixels, of the longest image side.
        pad_divisor: Height and width are padded up to a multiple of this.

    Returns:
        ``preprocess_image_fn(pil_img) -> (x, mask)``. ``x`` has shape
        ``[1, 3, H_pad, W_pad]`` and ``mask`` has shape ``[1, H_pad, W_pad]``.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def _resize_longest(pil_img: Image.Image, longest_side=1024):
        """Resize so the longest side equals ``longest_side``, keeping the aspect ratio."""
        w, h = pil_img.size
        if w <= 0 or h <= 0:
            raise ValueError(f"Cannot resize an empty image of size {w}x{h}.")
        scale = float(longest_side) / max(w, h)
        # Clamp to >= 1 px so very elongated images never end up with a zero-sized side.
        new_w = max(1, int(round(w * scale)))
        new_h = max(1, int(round(h * scale)))
        return pil_img.resize((new_w, new_h), Image.BICUBIC)

    def preprocess_image_fn(pil_img: Image.Image):
        """Return the normalized, padded image batch and its padding mask."""
        # Normalize() uses 3-channel statistics; grayscale or RGBA inputs would fail there.
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        img_resized = _resize_longest(pil_img, longest_side=longest)
        x = normalize(to_tensor(img_resized))  # [3, H, W]
        _, H, W = x.shape
        # Pad bottom/right to a multiple of pad_divisor for the backbone.
        H_pad = math.ceil(H / pad_divisor) * pad_divisor
        W_pad = math.ceil(W / pad_divisor) * pad_divisor
        pad_h, pad_w = H_pad - H, W_pad - W
        x = F.pad(x, (0, pad_w, 0, pad_h), value=0.0)  # [3, H_pad, W_pad]
        # Mask: True on padded pixels, False on real image content.
        # Empty slices (no padding on that axis) are no-ops.
        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
        mask[H:, :] = True
        mask[:, W:] = True
        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)

    return preprocess_image_fn
# -----------------------------
# Tokenizer (BiomedVLP-CXR-BERT)
# -----------------------------
# Local directory holding the BiomedVLP-CXR-BERT tokenizer files.
BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"
# Loaded once, at import time, so importing this module requires the files above.
# Used by tokenize_with_offsets() with return_offsets_mapping=True.
# NOTE(review): offset mappings presumably need a "fast" tokenizer to be available here — confirm.
_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)
def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
    """Tokenize ``prompt`` with the module tokenizer, keeping character offsets.

    Special tokens are added, and the sequence is truncated to the tokenizer's limit.

    Returns:
        dict with:
          "input_ids", "attention_mask": ``[1, L]`` tensors moved to ``device``;
          "tokens": the ``L`` token strings of the single sequence;
          "offsets": ``[start, end)`` character spans into ``prompt``, one per token
          (special tokens typically report ``(0, 0)``).
    """
    encoding = _tokenizer(
        prompt,
        add_special_tokens=True,
        truncation=True,
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    return dict(
        input_ids=input_ids.to(device),
        attention_mask=encoding["attention_mask"].to(device),
        tokens=_tokenizer.convert_ids_to_tokens(input_ids[0]),
        offsets=encoding["offset_mapping"][0].tolist(),
    )
def find_token_span_by_offsets(prompt: str, offsets, term: str):
    """Return indices of tokens whose character span overlaps ``term`` in ``prompt``.

    Only the first occurrence of ``term`` is used. Matching is case-insensitive.
    A whole-word match is preferred, with a plain substring match as the fallback.

    Args:
        prompt: Text that ``offsets`` index into.
        offsets: Per-token ``(start, end)`` character spans. Spans that are None,
            negative, or ``(0, 0)`` (special tokens / padding) are skipped.
        term: Phrase to locate.

    Returns:
        List of token indices overlapping the match's ``[start, end)`` range, in
        token order. Empty when ``term`` is not found.
    """
    escaped = re.escape(term)
    # Search the ORIGINAL prompt case-insensitively. lower() can change the string
    # length for some characters (e.g. "İ" -> "i̇"), which would shift match
    # positions relative to the tokenizer offsets.
    m = (
        re.search(r"\b" + escaped + r"\b", prompt, flags=re.IGNORECASE)
        or re.search(escaped, prompt, flags=re.IGNORECASE)
    )
    if m is None:
        return []
    a, b = m.span()
    return [
        i
        for i, (u, v) in enumerate(offsets)
        if not (u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0))
        and not (v <= a or u >= b)  # overlaps [a, b)
    ]
def _has_real_offsets(u, v):
    """True when ``(u, v)`` is a usable character span (not None, negative, or ``(0, 0)``)."""
    return not (u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0))


def _offset_aligned_text(tokens, offsets):
    """Rebuild a string whose character positions line up with ``offsets``.

    Each non-special token's text (with any WordPiece ``##`` prefix removed) is
    written at its ``[start, end)`` span, and gaps become spaces. Token text whose
    length differs from its span (e.g. ``[UNK]``) is truncated or space-padded so
    that later positions stay aligned.
    """
    spans = [
        (tok, u, v)
        for tok, (u, v) in zip(tokens, offsets)
        if tok not in ("[CLS]", "[SEP]", "[PAD]") and _has_real_offsets(u, v) and v > u
    ]
    chars = [" "] * max((v for _, _, v in spans), default=0)
    for tok, u, v in spans:
        piece = tok[2:] if tok.startswith("##") else tok
        chars[u:v] = piece[:v - u].ljust(v - u)
    return "".join(chars)


def model_span_indices_for_term(tokens, offsets, attn_T, term: str):
    """Map ``term`` to positions on the model's text-attention axis.

    Args:
        tokens: Hugging Face token strings, including special tokens.
        offsets: ``(start, end)`` character spans for ``tokens``, indexing the
            original prompt.
        attn_T: Length of the attention weights' text axis.
        term: Phrase to locate (case-insensitive).

    Returns:
        ``torch.long`` tensor of positions in ``[0, attn_T)``. Empty when the term
        is blank or not found.

    NOTE(review): assumes the model's text axis lists only the non-special
    tokens, in order — confirm against the GroundingDINO text encoder.
    """
    if not term or not term.strip():
        # A blank term would "match" every token in the substring fallback below.
        return torch.tensor([], dtype=torch.long)

    # 1) HF token indices whose span overlaps the term. The offsets index the
    #    ORIGINAL prompt, so search a string rebuilt to line up with them. Joining
    #    the raw tokens would shift positions ("[CLS]", "##", missing spaces).
    aligned_text = _offset_aligned_text(tokens, offsets)
    raw_hf_idxs = find_token_span_by_offsets(aligned_text, offsets, term)
    if not raw_hf_idxs:
        low = term.lower()
        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]

    # 2) Map HF non-special indices -> model positions 0..attn_T-1.
    non_special_hf = [
        i
        for i, (tok_i, (u, v)) in enumerate(zip(tokens, offsets))
        if tok_i not in ("[CLS]", "[SEP]", "[PAD]") and _has_real_offsets(u, v)
    ][:attn_T]
    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]
    return torch.tensor(model_term_idxs, dtype=torch.long)
# -----------------------------
# Cross-attention recorder
# -----------------------------
class CrossAttnRecorder:
    """Record cross-attention weights from decoder layers via forward hooks.

    For every layer in ``decoder_layers`` that has an attribute named
    ``attn_attr_name``:

    * if that module is an ``nn.MultiheadAttention``, its ``forward`` is
      temporarily replaced by a wrapper forcing ``need_weights=True`` and
      ``average_attn_weights=False``, so per-head weights are returned;
    * a forward hook appends the weights to ``attn_weights`` as CPU float32
      tensors. It uses ``output[1]`` when the output is a tuple of length >= 2,
      otherwise the module's ``attn_output_weights`` attribute if it has one.

    Call ``close()`` (or use the recorder as a context manager) to remove the
    hooks and restore the original ``forward``. Recorded weights remain in
    ``attn_weights`` after closing.
    """

    def __init__(self, decoder_layers, attn_attr_name='ca_text'):
        # One entry per hooked call. For nn.MultiheadAttention with
        # average_attn_weights=False these are [B, heads, Q, T].
        self.attn_weights = []
        # (module, hook_handle, original_forward or None when forward was not wrapped)
        self.handles = []
        self._register(decoder_layers, attn_attr_name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _hook(self, module, inputs, output):
        """Forward hook: stash the attention weights produced by ``module``."""
        if isinstance(output, tuple) and len(output) >= 2:
            attn_w = output[1]
        elif hasattr(module, 'attn_output_weights'):
            attn_w = module.attn_output_weights
        else:
            attn_w = None
        if attn_w is not None:
            self.attn_weights.append(attn_w.detach().to('cpu', dtype=torch.float32))

    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
        """Return ``(original_forward, wrapper)``; the wrapper forces per-head weights."""
        orig_forward = mha_module.forward

        def wrapped_forward(*args, **kwargs):
            kwargs['need_weights'] = True
            kwargs['average_attn_weights'] = False
            return orig_forward(*args, **kwargs)

        return orig_forward, wrapped_forward

    def _register(self, decoder_layers, attn_attr_name):
        """Attach hooks (and MHA forward wrappers) to each layer's attention module."""
        for layer in decoder_layers:
            attn_module = getattr(layer, attn_attr_name, None)
            if attn_module is None:
                continue
            orig_fwd = None
            if isinstance(attn_module, nn.MultiheadAttention):
                orig_fwd, wrapped = self._wrap_forward(attn_module)
                attn_module.forward = wrapped
            handle = attn_module.register_forward_hook(self._hook)
            self.handles.append((attn_module, handle, orig_fwd))

    @staticmethod
    def _restore_forward(module, orig_forward):
        """Undo the ``forward`` override installed by ``_register``."""
        class_forward = getattr(type(module), 'forward', None)
        if (getattr(orig_forward, '__self__', None) is module
                and getattr(orig_forward, '__func__', None) is class_forward):
            # The original was the ordinary class method. Drop the instance
            # attribute instead of pinning a bound method on the instance
            # (which would create a module -> method -> module reference cycle).
            module.__dict__.pop('forward', None)
        else:
            module.forward = orig_forward

    def close(self):
        """Remove hooks and restore wrapped ``forward`` methods; safe to call repeatedly."""
        # Undo in reverse registration order so that a module hooked more than
        # once ends up with its true original forward.
        while self.handles:
            attn_module, handle, orig_fwd = self.handles.pop()
            handle.remove()
            if orig_fwd is not None and isinstance(attn_module, nn.MultiheadAttention):
                self._restore_forward(attn_module, orig_fwd)
# -----------------------------
# Heatmap helpers
# -----------------------------
def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
    """Paint weighted boxes into a ``[H, W]`` float32 heatmap normalized to ``[0, 1]``.

    Args:
        boxes_xyxy: Iterable of pixel boxes ``(x1, y1, x2, y2)`` (tensors; ``.tolist()``
            is used). Coordinates are truncated to int and clipped to the image.
        weights: Tensor with one weight per box.
        hw: ``(H, W)`` of the output heatmap.
        score_scale: Optional tensor multiplied element-wise into ``weights``.
        blur_ksize: Gaussian blur kernel size. Blur is applied only when it is an
            odd number >= 3.
        blur_sigma: Gaussian sigma passed to ``cv2.GaussianBlur``.

    Returns:
        ``np.float32`` array of shape ``(H, W)``. It is divided by its maximum when
        that maximum exceeds 1e-6, otherwise it is returned unscaled.
    """
    H, W = hw
    heat = np.zeros((H, W), dtype=np.float32)
    w = weights.detach().cpu().numpy()
    if score_scale is not None:
        w = w * score_scale.detach().cpu().numpy()
    for i, box in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = map(int, box.tolist())
        # Slice ends are exclusive, so clamp to W/H (not W-1/H-1). This lets boxes
        # touching the right/bottom edge cover the last column/row.
        x1 = max(0, min(W, x1))
        x2 = max(0, min(W, x2))
        y1 = max(0, min(H, y1))
        y2 = max(0, min(H, y2))
        if x2 <= x1 or y2 <= y1:
            continue
        heat[y1:y2, x1:x2] += float(w[i])
    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)
    mx = heat.max()
    if mx > 1e-6:
        heat /= mx
    return heat
def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
    """Blend a colormapped heatmap onto an image.

    Args:
        img_pil: Base image (converted to RGB).
        heatmap: 2-D array of values, clipped to ``[0, 1]``. If its shape differs
            from the image's ``(H, W)``, it is bilinearly resized to match.
        alpha: Heatmap weight in the blend; the image gets ``1 - alpha``.
        cmap: OpenCV colormap id.

    Returns:
        The blended RGB image as a new ``PIL.Image``.
    """
    img = np.array(img_pil.convert("RGB"))
    H, W = img.shape[:2]
    heat = np.asarray(heatmap, dtype=np.float32)
    if heat.shape[:2] != (H, W):
        # cv2.addWeighted needs equal sizes; dsize is (width, height).
        heat = cv2.resize(heat, (W, H), interpolation=cv2.INTER_LINEAR)
    h = (np.clip(heat, 0, 1) * 255).astype(np.uint8)
    # applyColorMap returns BGR; flip to RGB to match the PIL image.
    h_color = np.ascontiguousarray(cv2.applyColorMap(h, cmap)[:, :, ::-1])
    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
    return Image.fromarray(blended)
def load_image_keep_longest(path, longest=1024):
    """Load an image as RGB and resize it so its longest side equals ``longest``.

    The aspect ratio is kept and resampling is bicubic.

    Raises:
        ValueError: If the image has a zero-sized dimension.
    """
    # Use a context manager so the file handle is released deterministically.
    with Image.open(path) as src:
        img = src.convert("RGB")
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError(f"Cannot resize an empty image of size {w}x{h}: {path!r}")
    s = float(longest) / max(w, h)
    # Clamp to >= 1 px so very elongated images never end up with a zero-sized side.
    new_w = max(1, int(round(w * s)))
    new_h = max(1, int(round(h * s)))
    return img.resize((new_w, new_h), Image.BICUBIC)
# -----------------------------
# Main helper: one call from API
# -----------------------------
_TERM_AGG_MODES = ("mean", "max", "sum")


def _minmax_normalize(x):
    """Rescale a 1-D tensor to [0, 1]; the epsilon keeps constant inputs from dividing by zero."""
    return (x - x.min()) / (x.max() - x.min() + 1e-6)


def _select_kept_queries(scores, score_thresh, topk):
    """Return indices of the decoder queries whose boxes are drawn.

    Queries scoring above ``score_thresh`` are kept, best first, capped at
    ``topk``. If none pass, the ``topk`` highest-scoring queries are used instead.
    """
    keep = torch.nonzero(scores > score_thresh).squeeze(1)
    if keep.numel() == 0:
        return torch.argsort(scores, descending=True)[:min(topk, scores.numel())]
    # nonzero() yields indices in query order; rank them by score before
    # truncating so "topk" really keeps the highest-scoring queries.
    order = torch.argsort(scores[keep], descending=True)
    return keep[order][:topk]


def _cxcywh_to_xyxy_pixels(boxes_cxcywh, W, H):
    """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixel boxes."""
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    return torch.stack(
        [(cx - 0.5 * w) * W, (cy - 0.5 * h) * H, (cx + 0.5 * w) * W, (cy + 0.5 * h) * H],
        dim=-1,
    )


def _aggregate_term_attention(per_term_attn, term_agg):
    """Combine per-term ``[Q]`` attention vectors with "mean", "max" or "sum"."""
    agg = None
    for v in per_term_attn.values():
        if agg is None:
            agg = v
        elif term_agg == "max":
            agg = torch.maximum(agg, v)
        else:  # "mean" and "sum" both accumulate; "mean" divides below
            agg = agg + v
    if term_agg == "mean":
        agg = agg / float(len(per_term_attn))
    return agg


def _render_overlay(img_pil, boxes_xyxy, weights, kept_scores, path):
    """Paint score-scaled box weights as a blurred heatmap over ``img_pil`` and save it to ``path``."""
    W, H = img_pil.size
    heat = boxes_to_heatmap(
        boxes_xyxy=boxes_xyxy,
        weights=weights,
        hw=(H, W),
        score_scale=kept_scores,
        blur_ksize=61,
        blur_sigma=0,
    )
    overlay_heatmap(img_pil, heat, alpha=0.45).save(path)


@torch.no_grad()
def run_token_ca_visualization(
    cfg_path: str,
    ckpt_path: str,
    image_path: str,
    prompt: str,
    terms,
    out_dir: str,
    device: str = DEVICE_DEFAULT,
    score_thresh: float = 0.25,
    topk: int = 100,
    term_agg: str = "mean",  # "mean" | "max" | "sum"
    save_per_term: bool = True,
):
    """Render token->region cross-attention overlays for selected prompt terms.

    Loads GroundingDINO from ``cfg_path`` / ``ckpt_path`` and runs it once on
    ``image_path`` with ``prompt``, recording the decoder layers' ``ca_text``
    attention. Each kept box is then painted with the terms' attention weight,
    scaled by the box score, as a blurred heatmap over the image.

    Args:
        cfg_path: GroundingDINO config file.
        ckpt_path: GroundingDINO checkpoint.
        image_path: Input image. Its base name is used for the output file names.
        prompt: Text prompt fed to the model.
        terms: A term or a list of terms. Only non-blank terms occurring in
            ``prompt`` (case-insensitive) are used.
        out_dir: Output directory, created if needed.
        device: Torch device. A falsy value falls back to ``DEVICE_DEFAULT``.
        score_thresh: Minimum sigmoid score for a query's box to be drawn.
        topk: Maximum number of boxes drawn, highest-scoring first.
        term_agg: How per-term attention is combined for the combined overlay.
        save_per_term: Also save one overlay per term. This only happens when
            more than one term was matched.

    Returns:
        ``{"combined": <path>, "per_term": {term: <path>, ...}}``, or ``{}`` when
        no usable term is found in the prompt or in the model's text tokens.

    Raises:
        ValueError: If ``term_agg`` is not "mean", "max" or "sum".
        RuntimeError: If no cross-attention weights were captured.
    """
    if term_agg not in _TERM_AGG_MODES:
        raise ValueError(f"term_agg must be one of {_TERM_AGG_MODES}, got {term_agg!r}")
    if isinstance(terms, str):
        terms = [terms]
    prompt_lower = prompt.lower()
    # Keep only non-blank terms that actually appear in the prompt (case-insensitive).
    # Blank terms are dropped because an empty string "matches" every token.
    terms = [t for t in terms if t and t.strip() and t.lower() in prompt_lower]
    if not terms:
        print(f"[TokenCA] No configured terms found in prompt: {prompt!r}")
        return {}

    device = device or DEVICE_DEFAULT
    model = load_model(cfg_path, ckpt_path).to(device).eval()
    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)
    img_pil = load_image_keep_longest(image_path, longest=1024)
    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")

    # ---- forward pass with cross-attention hooks; always unhook, even on error
    recorder = CrossAttnRecorder(model.transformer.decoder.layers, attn_attr_name="ca_text")
    try:
        img_tensor, mask = preprocess_image_fn(img_pil)
        outputs = model(NestedTensor(img_tensor, mask), captions=[prompt])
    finally:
        recorder.close()
    if not recorder.attn_weights:
        raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")

    # ---- select boxes to draw
    scores, _ = outputs["pred_logits"][0].sigmoid().max(dim=1)
    keep = _select_kept_queries(scores, score_thresh, topk)
    kept_scores = scores[keep]
    keep_cpu = keep.cpu()
    W, H = img_pil.size
    boxes_xyxy = _cxcywh_to_xyxy_pixels(outputs["pred_boxes"][0][keep], W, H)

    # ---- [Q, T] attention: mean over heads within each layer, then over layers.
    # Assumes each recorded tensor is [1, heads, Q, T] (one image per batch).
    attn_qt = torch.stack([w.squeeze(0).mean(0) for w in recorder.attn_weights], 0).mean(0)

    # ---- per-term attention for every query
    tok = tokenize_with_offsets(prompt, device="cpu")
    tokens, offsets = tok["tokens"], tok["offsets"]
    T_text = attn_qt.shape[1]
    per_term_attn_kept = {}
    per_term_attn_full = {}
    for t in terms:
        model_idxs = model_span_indices_for_term(tokens, offsets, T_text, t)
        if model_idxs.numel() == 0:
            continue
        attn_per_query = attn_qt[:, model_idxs].mean(1)  # [Q]
        per_term_attn_full[t] = attn_per_query
        per_term_attn_kept[t] = _minmax_normalize(attn_per_query[keep_cpu])
    if not per_term_attn_kept:
        print(f"[TokenCA] None of the terms were found in the first T tokens: {terms}")
        return {}

    # ---- combined overlay
    agg = _aggregate_term_attention(per_term_attn_full, term_agg)
    _render_overlay(img_pil, boxes_xyxy, _minmax_normalize(agg[keep_cpu]), kept_scores, combined_path)

    # ---- optional per-term overlays. With a single term they would duplicate
    # the combined overlay, so they are written only when several terms matched.
    per_term_paths = {}
    if save_per_term and len(per_term_attn_kept) > 1:
        for t, v in per_term_attn_kept.items():
            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32]
            p = os.path.join(out_dir, f"{base_name}__{term_tag}.png")
            _render_overlay(img_pil, boxes_xyxy, v, kept_scores, p)
            per_term_paths[t] = p
    return {
        "combined": combined_path,
        "per_term": per_term_paths,
    }