# vg_token_attention.py
# -*- coding: utf-8 -*-
"""
Token→region cross-attention visualization for GroundingDINO, integrated as a helper.

Usage from other modules:

    from vg_token_attention import run_token_ca_visualization

    paths = run_token_ca_visualization(
        cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
        ckpt_path="VG/weights/checkpoint0399_log4.pth",
        image_path=image_path,
        prompt=text_prompt,
        terms=chexbert_terms,          # e.g. ["edema", "effusion"]
        out_dir="outputs/attn_overlays",
        device="cuda",                 # or "cpu"
    )
"""
import os
import math
import re

import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from PIL import Image
import torchvision.transforms as T

from VG.groundingdino.util.inference import load_model
from VG.groundingdino.util.misc import NestedTensor
from transformers import AutoTokenizer

DEVICE_DEFAULT = "cuda" if torch.cuda.is_available() else "cpu"
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Token strings treated as "special" (never mapped to a model text position).
_SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")
# Supported values for the ``term_agg`` argument of run_token_ca_visualization.
_TERM_AGG_MODES = ("mean", "max", "sum")


# -----------------------------
# Image helpers
# -----------------------------
def _resize_longest(pil_img, longest_side=1024):
    """Resize *pil_img* (bicubic) so its longest side equals *longest_side*, keeping aspect."""
    w, h = pil_img.size
    scale = float(longest_side) / max(w, h)
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return pil_img.resize((new_w, new_h), Image.BICUBIC)


def _to_numpy(x):
    """Return *x* as a numpy array, detaching/moving torch tensors to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# -----------------------------
# Preprocess: PIL -> (tensor, mask)
# -----------------------------
def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
    """Build a callable that turns a PIL image into a padded ``(tensor, mask)`` pair.

    The returned function:
      1. resizes the image so its longest side is *longest*;
      2. converts to a tensor and applies ImageNet mean/std normalization;
      3. zero-pads bottom/right so height and width are multiples of *pad_divisor*;
      4. returns ``(x, mask)`` with ``x`` shaped ``[1, 3, Hp, Wp]`` and ``mask`` a
         ``[1, Hp, Wp]`` bool tensor that is True on padded pixels, both on *device*.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def preprocess_image_fn(pil_img):
        # convert("RGB") guards against grayscale/RGBA inputs, which would not
        # match the 3-channel normalization statistics.
        img_resized = _resize_longest(pil_img.convert("RGB"), longest_side=longest)
        x = normalize(to_tensor(img_resized))  # [3, H, W]
        _, H, W = x.shape

        # Pad to a multiple of pad_divisor for the backbone.
        H_pad = math.ceil(H / pad_divisor) * pad_divisor
        W_pad = math.ceil(W / pad_divisor) * pad_divisor
        x = F.pad(x, (0, W_pad - W, 0, H_pad - H), value=0.0)  # [3, Hp, Wp]

        # Mask: True on padded pixels (empty slices are no-ops when no padding).
        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
        mask[H:, :] = True
        mask[:, W:] = True
        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)

    return preprocess_image_fn


# -----------------------------
# Tokenizer (BiomedVLP-CXR-BERT)
# -----------------------------
BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"
_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)


def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
    """Tokenize *prompt* with the module tokenizer and keep character offsets.

    Returns a dict with ``input_ids`` / ``attention_mask`` tensors (on *device*),
    the token strings (``tokens``) and per-token ``(start, end)`` character
    offsets into *prompt* (``offsets``).
    """
    enc = _tokenizer(
        prompt,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
    )
    tokens = _tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    offsets = enc["offset_mapping"][0].tolist()
    return {
        "input_ids": enc["input_ids"].to(device),
        "attention_mask": enc["attention_mask"].to(device),
        "tokens": tokens,
        "offsets": offsets,
    }


def _is_special_offset(u, v):
    """Return True when an offset pair is missing, negative, or the (0, 0) special-token marker."""
    return u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0)


def find_token_span_by_offsets(prompt: str, offsets, term: str):
    """Return indices of tokens whose character span overlaps the first match of *term*.

    Matching is case-insensitive and prefers a whole-word match, falling back to
    a plain substring match. Matching is done on the unmodified *prompt* (via
    ``re.IGNORECASE``) so character positions stay aligned with *offsets*;
    lowercasing first can change string length for some Unicode characters.
    Tokens with special offsets (see ``_is_special_offset``) are skipped.
    """
    if not term:
        return []
    pattern = re.escape(term)
    m = re.search(r"\b" + pattern + r"\b", prompt, flags=re.IGNORECASE) or re.search(
        pattern, prompt, flags=re.IGNORECASE
    )
    if not m:
        return []
    a, b = m.span()
    return [
        i
        for i, (u, v) in enumerate(offsets)
        if not _is_special_offset(u, v) and u < b and v > a  # overlap with [a, b)
    ]


def model_span_indices_for_term(tokens, offsets, attn_T, term: str, prompt=None):
    """Map *term* to the model's text-position indices (a LongTensor, possibly empty).

    Args:
        tokens: token strings from ``tokenize_with_offsets``.
        offsets: matching ``(start, end)`` character offsets into the prompt.
        attn_T: number of text positions in the captured attention maps.
        term: phrase to locate.
        prompt: the exact string that was tokenized. Pass it whenever possible:
            *offsets* index into this string. When omitted, a string rebuilt by
            concatenating *tokens* is used (legacy behavior); its character
            positions generally do NOT line up with *offsets*.
    """
    if not term:
        return torch.tensor([], dtype=torch.long)

    # 1) HF token indices by character offsets
    if prompt is None:
        prompt = "".join(t if t != "[PAD]" else " " for t in tokens)
    raw_hf_idxs = find_token_span_by_offsets(prompt, offsets, term)
    if not raw_hf_idxs:
        # Fallback: the term appears verbatim inside a single token string.
        low = term.lower()
        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]

    # 2) Map HF non-special tokens -> model positions 0..attn_T-1.
    # NOTE(review): this assumes the model's text axis excludes [CLS]/[SEP].
    # Upstream GroundingDINO keeps every tokenized position (including [CLS])
    # in its text memory; if this fork does too, these indices are shifted by
    # one. Confirm against the model before trusting per-token maps.
    non_special_hf = [
        i
        for i, (tok_i, (u, v)) in enumerate(zip(tokens, offsets))
        if tok_i not in _SPECIAL_TOKENS and not _is_special_offset(u, v)
    ][:attn_T]
    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]
    return torch.tensor(model_term_idxs, dtype=torch.long)


# -----------------------------
# Cross-attention recorder
# -----------------------------
class CrossAttnRecorder:
    """Record per-layer attention weights from ``getattr(layer, attn_attr_name)``.

    For ``nn.MultiheadAttention`` submodules, ``forward`` is wrapped so it is
    called with ``need_weights=True, average_attn_weights=False``; the forward
    hook then stores ``output[1]``. Other module types are hooked as-is: the
    hook uses ``output[1]`` if the output is a tuple of length >= 2, otherwise
    the module's ``attn_output_weights`` attribute if present.

    Captured tensors are detached, moved to CPU and cast to float32 in
    ``self.attn_weights`` (one entry per hooked call). Use as a context manager
    (or call ``close()``) to remove hooks and restore original forwards.
    """

    def __init__(self, decoder_layers, attn_attr_name="ca_text"):
        self.attn_weights = []  # MHA path: [B, heads, Q, T] per call
        self.handles = []  # (module, hook handle, original forward or None)
        self._register(decoder_layers, attn_attr_name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def _hook(self, module, inputs, output):
        if isinstance(output, tuple) and len(output) >= 2:
            attn_w = output[1]
        elif hasattr(module, "attn_output_weights"):
            attn_w = module.attn_output_weights
        else:
            attn_w = None
        if attn_w is not None:
            self.attn_weights.append(attn_w.detach().to("cpu", dtype=torch.float32))

    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
        orig_forward = mha_module.forward

        def wrapped_forward(*args, **kwargs):
            # Forces per-head weights. Assumes callers pass these two options
            # by keyword (a positional need_weights would conflict here).
            kwargs["need_weights"] = True
            kwargs["average_attn_weights"] = False
            return orig_forward(*args, **kwargs)

        return orig_forward, wrapped_forward

    def _register(self, decoder_layers, attn_attr_name):
        for layer in decoder_layers:
            attn_module = getattr(layer, attn_attr_name, None)
            if attn_module is None:
                continue
            if isinstance(attn_module, nn.MultiheadAttention):
                orig_fwd, wrapped = self._wrap_forward(attn_module)
                attn_module.forward = wrapped
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, orig_fwd))
            else:
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, None))

    def close(self):
        """Remove hooks and restore wrapped forwards. Safe to call more than once."""
        for attn_module, handle, orig_fwd in self.handles:
            handle.remove()
            if orig_fwd is not None and isinstance(attn_module, nn.MultiheadAttention):
                attn_module.forward = orig_fwd
        self.handles = []


# -----------------------------
# Heatmap helpers
# -----------------------------
def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
    """Paint weighted boxes into an ``(H, W)`` float32 heatmap normalized to max 1.

    Args:
        boxes_xyxy: ``N`` pixel boxes ``(x1, y1, x2, y2)`` (tensor, array or list).
        weights: ``N`` per-box weights (tensor, array or list).
        hw: ``(H, W)`` output size.
        score_scale: optional ``N`` multipliers applied to *weights*.
        blur_ksize: Gaussian kernel size; blurring only runs for odd values >= 3.
        blur_sigma: Gaussian sigma (0 lets OpenCV derive it from the kernel size).

    Box coordinates are truncated to ints and clamped to the image. ``x2``/``y2``
    are exclusive ends, so a box reaching the border covers the last row/column.
    Boxes that become empty after clamping are skipped. If the result's maximum
    is <= 1e-6 it is returned unnormalized.
    """
    H, W = hw
    heat = np.zeros((H, W), dtype=np.float32)
    w = _to_numpy(weights)
    if score_scale is not None:
        w = w * _to_numpy(score_scale)
    # Convert once (avoids one device->host sync per box for CUDA tensors).
    boxes = _to_numpy(boxes_xyxy).reshape(-1, 4)

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = (int(c) for c in box)
        x1 = max(0, min(W, x1)); x2 = max(0, min(W, x2))
        y1 = max(0, min(H, y1)); y2 = max(0, min(H, y2))
        if x2 <= x1 or y2 <= y1:
            continue
        heat[y1:y2, x1:x2] += float(w[i])

    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)

    mx = heat.max()
    if mx > 1e-6:
        heat /= mx
    return heat


def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
    """Blend a [0, 1] heatmap (colorized with *cmap*) over *img_pil*; return a PIL image.

    Values are clipped to [0, 1]. A heatmap whose size differs from the image
    is bilinearly resized to the image size first.
    """
    img = np.array(img_pil.convert("RGB"))
    H, W = img.shape[:2]
    heatmap = np.asarray(heatmap)
    if heatmap.shape[:2] != (H, W):
        heatmap = cv2.resize(heatmap, (W, H), interpolation=cv2.INTER_LINEAR)
    h = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
    # applyColorMap yields BGR; convert to RGB to match the PIL-derived image.
    h_color = cv2.cvtColor(cv2.applyColorMap(h, cmap), cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
    return Image.fromarray(blended)


def load_image_keep_longest(path, longest=1024):
    """Open *path* as RGB and resize so its longest side equals *longest*."""
    img = Image.open(path).convert("RGB")
    return _resize_longest(img, longest_side=longest)


# -----------------------------
# Internal steps of the main helper
# -----------------------------
def _normalize_terms(terms, prompt):
    """Strip, drop empty/non-str, de-duplicate (case-insensitive) and keep terms present in *prompt*."""
    if isinstance(terms, str):
        terms = [terms]
    prompt_lower = prompt.lower()
    seen = set()
    kept = []
    for t in terms or []:
        if not isinstance(t, str):
            continue
        t = t.strip()
        key = t.lower()
        if not t or key in seen or key not in prompt_lower:
            continue
        seen.add(key)
        kept.append(t)
    return kept


def _select_queries(scores, score_thresh, topk):
    """Return indices of the highest-scoring queries above *score_thresh* (at most *topk*).

    If no query passes the threshold, the top-*topk* queries are returned
    regardless of score. At least one query is kept whenever any exist.
    """
    k = max(1, min(int(topk), scores.numel()))
    order = torch.argsort(scores, descending=True)
    above = order[scores[order] > score_thresh]
    return above[:k] if above.numel() > 0 else order[:k]


def _cxcywh_to_xyxy_pixels(boxes_cxcywh, width, height):
    """Convert normalized ``(cx, cy, w, h)`` boxes to pixel ``(x1, y1, x2, y2)``."""
    cx, cy, bw, bh = boxes_cxcywh.unbind(-1)
    return torch.stack(
        [
            (cx - 0.5 * bw) * width,
            (cy - 0.5 * bh) * height,
            (cx + 0.5 * bw) * width,
            (cy + 0.5 * bh) * height,
        ],
        dim=-1,
    )


def _query_token_attention(attn_weights):
    """Average captured per-layer attention into one ``[Q, T]`` map (batch index 0).

    4-D tensors are treated as ``[B, heads, Q, T]`` and averaged over heads;
    3-D tensors are assumed to be already head-averaged ``[B, Q, T]``
    (TODO confirm for non-MHA modules); 2-D tensors are used as ``[Q, T]``.
    """
    per_layer = []
    for w in attn_weights:
        if w.dim() == 4:
            w = w[0].mean(0)
        elif w.dim() == 3:
            w = w[0]
        elif w.dim() != 2:
            raise ValueError(f"Unexpected attention weight shape: {tuple(w.shape)}")
        per_layer.append(w)
    return torch.stack(per_layer, 0).mean(0)


def _minmax(x):
    """Min-max normalize a 1-D tensor to [0, 1) (epsilon-guarded); empty input returned as-is."""
    if x.numel() == 0:
        return x
    return (x - x.min()) / (x.max() - x.min() + 1e-6)


def _aggregate(per_term, term_agg):
    """Combine per-term ``[Q]`` attention vectors by ``"sum"``, ``"max"`` or ``"mean"``."""
    stacked = torch.stack(list(per_term.values()), 0)
    if term_agg == "sum":
        return stacked.sum(0)
    if term_agg == "max":
        return stacked.max(0).values
    return stacked.mean(0)


def _save_overlay(img_pil, boxes_xyxy, weights, kept_scores, path):
    """Render the weighted-box heatmap over *img_pil* and save it to *path*."""
    W, H = img_pil.size
    heat = boxes_to_heatmap(
        boxes_xyxy=boxes_xyxy,
        weights=weights,
        hw=(H, W),
        score_scale=kept_scores,
        blur_ksize=61,
        blur_sigma=0,
    )
    overlay_heatmap(img_pil, heat, alpha=0.45).save(path)


# -----------------------------
# Main helper: one call from API
# -----------------------------
@torch.no_grad()
def run_token_ca_visualization(
    cfg_path: str,
    ckpt_path: str,
    image_path: str,
    prompt: str,
    terms,
    out_dir: str,
    device: str = DEVICE_DEFAULT,
    score_thresh: float = 0.25,
    topk: int = 100,
    term_agg: str = "mean",  # "mean" | "max" | "sum"
    save_per_term: bool = True,
):
    """Save token→region cross-attention overlays for *terms* found in *prompt*.

    Loads the model, runs one forward pass while recording the decoder
    ``ca_text`` attention, keeps the best-scoring queries (see *score_thresh* /
    *topk*), and paints each kept box weighted by its attention to the term
    tokens times its detection score.

    Args:
        terms: a string or iterable of strings; only terms that occur in
            *prompt* (case-insensitive) are used.
        term_agg: how per-term attention is combined for the combined overlay.
        save_per_term: also save one overlay per term (only when more than one
            term survives; with a single term it would equal the combined one).

    Returns:
        ``{"combined": path_to_combined_overlay, "per_term": {term: path, ...}}``,
        or ``{}`` when no term is found in the prompt / in the model's tokens.

    Raises:
        ValueError: if *term_agg* is not one of "mean", "max", "sum".
        RuntimeError: if no attention weights were captured.
    """
    if term_agg not in _TERM_AGG_MODES:
        raise ValueError(f"term_agg must be one of {_TERM_AGG_MODES}, got {term_agg!r}")

    terms = _normalize_terms(terms, prompt)
    if not terms:
        print(f"[TokenCA] No configured terms found in prompt: {prompt!r}")
        return {}

    device = device or DEVICE_DEFAULT
    # NOTE(review): upstream GroundingDINO's load_model also takes a `device`
    # argument (default "cuda"); confirm this fork builds fine on CPU-only hosts.
    model = load_model(cfg_path, ckpt_path).to(device).eval()
    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)
    img_pil = load_image_keep_longest(image_path, longest=1024)

    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")

    # preprocess → NestedTensor
    img_tensor, mask = preprocess_image_fn(img_pil)
    samples = NestedTensor(img_tensor, mask)

    # Hooks are removed even if the forward pass raises.
    with CrossAttnRecorder(model.transformer.decoder.layers, attn_attr_name="ca_text") as recorder:
        outputs = model(samples, captions=[prompt])
    if not recorder.attn_weights:
        raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")
    attn_qt = _query_token_attention(recorder.attn_weights)  # [Q, T] on CPU

    # Select queries and decode their boxes to pixel xyxy.
    scores = outputs["pred_logits"][0].sigmoid().max(dim=1).values
    keep = _select_queries(scores, score_thresh, topk)
    if keep.numel() == 0:
        print("[TokenCA] Model returned no queries.")
        return {}
    W, H = img_pil.size
    boxes_xyxy = _cxcywh_to_xyxy_pixels(outputs["pred_boxes"][0][keep], W, H)
    kept_scores = scores[keep]
    keep_cpu = keep.cpu()

    # Per-term attention: mean over the term's tokens, per query.
    tok = tokenize_with_offsets(prompt, device="cpu")
    tokens, offsets = tok["tokens"], tok["offsets"]
    T_text = attn_qt.shape[1]

    per_term_attn_kept = {}
    per_term_attn_full = {}
    for t in terms:
        model_idxs = model_span_indices_for_term(tokens, offsets, T_text, t, prompt=prompt)
        if model_idxs.numel() == 0:
            continue
        attn_per_query = attn_qt[:, model_idxs].mean(1)  # [Q]
        per_term_attn_kept[t] = _minmax(attn_per_query[keep_cpu])
        per_term_attn_full[t] = attn_per_query

    if not per_term_attn_kept:
        print(f"[TokenCA] None of the terms were found in the first T tokens: {terms}")
        return {}

    # Combined overlay: aggregate over all queries, then normalize the kept ones.
    agg_kept = _minmax(_aggregate(per_term_attn_full, term_agg)[keep_cpu])
    _save_overlay(img_pil, boxes_xyxy, agg_kept, kept_scores, combined_path)

    per_term_paths = {}
    if save_per_term and len(per_term_attn_kept) > 1:
        used_tags = set()
        for t, v in per_term_attn_kept.items():
            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32] or "term"
            # Avoid overwriting when two terms sanitize to the same tag.
            unique_tag, n = term_tag, 1
            while unique_tag in used_tags:
                n += 1
                unique_tag = f"{term_tag}_{n}"
            used_tags.add(unique_tag)
            p = os.path.join(out_dir, f"{base_name}__{unique_tag}.png")
            _save_overlay(img_pil, boxes_xyxy, v, kept_scores, p)
            per_term_paths[t] = p

    return {
        "combined": combined_path,
        "per_term": per_term_paths,
    }