Displaying attention image

- app.py                 +72 -27
- inference.py           +56 -1
- vg_token_attention.py  +396 -0
- vicca_api.py           +28 -0
app.py (CHANGED)

@@ -5,45 +5,90 @@ import tempfile
 import gradio as gr
 from vicca_api import run_vicca
 
-def vicca_interface(image, text_prompt):
-    """
-    image: file from Gradio, we'll use its temp path
-    text_prompt: report / description
-    """
-    # Gradio passes a PIL image or a file path depending on type
-    # We'll request type='filepath' so this is already a path
-    image_path = image
+def vicca_interface(image, text_prompt, box_threshold=0.2, text_threshold=0.2, num_samples=4):
+    os.makedirs("uploads", exist_ok=True)
+    input_path = os.path.join("uploads", "input.png")
+    image.save(input_path)
 
     result = run_vicca(
-        image_path=image_path,
+        image_path=input_path,
         text_prompt=text_prompt,
+        box_threshold=box_threshold,
+        text_threshold=text_threshold,
+        num_samples=num_samples,
     )
 
-    # You could also return the best generated image as an image output
-    # For now, we expose the dict as JSON
-    return result
+    best_gen = result.get("best_generated_image_path")
+
+    attn = result.get("attention_overlays") or {}
+    combined = attn.get("combined")
+    per_term_dict = attn.get("per_term") or {}
+
+    gallery_items = [(p, term) for term, p in per_term_dict.items()]
+
+    return best_gen, combined, gallery_items, result
 
 demo = gr.Interface(
     fn=vicca_interface,
     inputs=[
-        gr.Image(type="filepath", label="Chest X-ray"),
-        gr.Textbox(label="Report / pathology description", lines=3),
+        gr.Image(type="pil", label="Input CXR"),
+        gr.Textbox(lines=3, label="Text prompt"),
+        gr.Slider(0.0, 1.0, value=0.2, label="Box threshold"),
+        gr.Slider(0.0, 1.0, value=0.2, label="Text threshold"),
+        gr.Slider(1, 8, step=1, value=4, label="Number of samples"),
     ],
-    outputs=gr.JSON(label="VICCA output"),
-    title="VICCA – Visual Interpretation & Comprehension",
-    description=(
-        "Upload a chest X-ray and provide a text report / pathology description. "
-        "The VICCA pipeline will run CXR generation, visual grounding, "
-        "and ROI-level similarity scoring."
-    ),
+    outputs=[
+        gr.Image(label="Best generated CXR"),
+        gr.Image(label="Combined attention heatmap"),
+        gr.Gallery(label="Per-term overlays").style(grid=[3], height=400),
+        gr.JSON(label="Raw VICCA output"),
+    ],
+    title="VICCA",
 )
 
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)
+
+
+# def vicca_interface(image, text_prompt):
+#     """
+#     image: file from Gradio, we'll use its temp path
+#     text_prompt: report / description
+#     """
+#     # Gradio passes a PIL image or a file path depending on type
+#     # We'll request type='filepath' so this is already a path
+#     image_path = image
+
+#     result = run_vicca(
+#         image_path=image_path,
+#         text_prompt=text_prompt,
+#     )
+
+#     # You could also return the best generated image as an image output
+#     # For now, we expose the dict as JSON
+#     return result
+
+# demo = gr.Interface(
+#     fn=vicca_interface,
+#     inputs=[
+#         gr.Image(type="filepath", label="Chest X-ray"),
+#         gr.Textbox(label="Report / pathology description", lines=3),
+#     ],
+#     outputs=gr.JSON(label="VICCA output"),
+#     title="VICCA – Visual Interpretation & Comprehension",
+#     description=(
+#         "Upload a chest X-ray and provide a text report / pathology description. "
+#         "The VICCA pipeline will run CXR generation, visual grounding, "
+#         "and ROI-level similarity scoring."
+#     ),
+# )
+
 # if __name__ == "__main__":
 #     demo.launch()
-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        debug=False
-    )
+# if __name__ == "__main__":
+#     demo.launch(
+#         server_name="0.0.0.0",
+#         server_port=7860,
+#         debug=False
+#     )
 
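A note on the Gradio API used above: `gr.Gallery(...).style(grid=[3], height=400)` is the Gradio 3.x idiom; on Gradio 4+ the `.style()` method is gone and the layout options move onto the constructor. A minimal sketch of the equivalent component under that assumption (parameter names as in Gradio 4):

    import gradio as gr

    # Gradio 4.x: layout options are constructor arguments instead of .style()
    per_term_gallery = gr.Gallery(label="Per-term overlays", columns=3, height=400)
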
inference.py (CHANGED)

@@ -9,11 +9,13 @@ python inference.py \
 
 """
 import pandas as pd
+import numpy as np
 import time
 import cv2
 import sys
 import argparse
 from ast import literal_eval
+from nltk import tokenize
 
 # sys.path.append('/home/gholipos-admin/Desktop/Thesis/Training_Code/VICCA')
 from pathlib import Path
@@ -101,7 +103,7 @@ from DETR import svc
 from DETR.arguments import get_args_parser as get_detr_args_parser
 from VG import localization
 from ssim import ssim
-
+from CheXbert.src.label import label
 
 def get_args_parser():
     parser = argparse.ArgumentParser('Set the Input', add_help=True)
@@ -120,6 +122,59 @@ def get_args_parser():
                         help="Path to save generated files.")
     return parser
 
+path_list = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
+             'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
+             'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
+             'Support Devices', 'No Finding']
+
+# Cache CheXbert weights once at import time
+CHEXBERT_WEIGHTS = get_weight("CheXbert/checkpoint/chexbert.pth")
+
+# def chexbert_pathology(text):
+#     sentences = list(set(tokenize.sent_tokenize(text)))
+#     path_dict = []
+#     for sentence in sentences:
+#         sentence = sentence.replace('\n',' ')
+#         sentence = sentence.replace('\s+',' ')
+#         chexbert_weight_path = get_weight("CheXbert/checkpoint/chexbert.pth")
+#         # pathology = np.array(label("CheXbert/checkpoint/chexbert.pth", sentence)).T[0]
+#         pathology = np.array(label(chexbert_weight_path, sentence)).T[0]
+#         if pathology[-1]==1 or len(list(set(pathology)))==1 or not any(e==1 for e in pathology):
+#             pass
+#         else:
+#             indice = [i for i, e in enumerate(pathology) if e==1]
+#             for ind in indice:
+#                 path_dict.append(path_list[ind])
+#     return path_dict
+def chexbert_pathology(text: str):
+    """
+    Run CheXbert on the text and return a list of *positive* pathology labels,
+    deduplicated.
+    """
+    # If NLTK punkt ever becomes a problem on Spaces, replace this with a simple split.
+    # sentences = list(set(tokenize.sent_tokenize(text)))
+    sentences = [s.strip() for s in text.split(".") if s.strip()]
+
+    path_terms = set()
+
+    for sentence in sentences:
+        sentence = sentence.replace("\n", " ")
+        sentence = sentence.replace("\s+", " ")
+
+        # Run CheXbert
+        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]
+
+        # Skip if: "No Finding" active, or all labels same, or no positives
+        if pathology[-1] == 1 or len(set(pathology)) == 1 or not any(e == 1 for e in pathology):
+            continue
+
+        # Collect positive indices
+        indices = [i for i, e in enumerate(pathology) if e == 1]
+        for ind in indices:
+            path_terms.add(path_list[ind])
+
+    return sorted(path_terms)
+
 def extract_tensor(value):
     cleaned_value = value.replace('tensor(', '').replace(')', '')
     return literal_eval(cleaned_value)
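To make the intended behavior of the new `chexbert_pathology` concrete, here is a small illustrative call. It assumes the inference.py environment (CheXbert checkpoint downloaded); the report text is invented and the returned labels depend entirely on the CheXbert model, so the printed list is only an example of the expected shape, a sorted, deduplicated subset of `path_list`:

    # Hypothetical report; the actual output depends on the CheXbert model's predictions.
    report = "There is a small right pleural effusion. Mild cardiomegaly is noted."
    terms = chexbert_pathology(report)
    print(terms)   # e.g. ['Cardiomegaly', 'Pleural Effusion']
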
vg_token_attention.py (ADDED, new file, +396 lines)

# vg_token_attention.py
# -*- coding: utf-8 -*-
"""
Token→region cross-attention visualization for GroundingDINO integrated as a helper.

Usage from other modules:
    from vg_token_attention import run_token_ca_visualization

    paths = run_token_ca_visualization(
        cfg_path="VG/config/GroundingDINO_SwinT_OGC_2.py",
        ckpt_path="VG/weights/checkpoint0399_log4.pth",
        image_path=image_path,
        prompt=text_prompt,
        terms=chexbert_terms,          # e.g. ["edema", "effusion"]
        out_dir="outputs/attn_overlays",
        device="cuda" or "cpu",
    )
"""

import os
import math
import re
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from PIL import Image
import torchvision.transforms as T

from VG.groundingdino.util.inference import load_model
from VG.groundingdino.util.misc import NestedTensor

from transformers import AutoTokenizer

DEVICE_DEFAULT = "cuda" if torch.cuda.is_available() else "cpu"
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


# -----------------------------
# Preprocess: PIL -> (tensor, mask)
# -----------------------------
def preprocess_image_fn_factory(device=DEVICE_DEFAULT, longest=1024, pad_divisor=32):
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def _resize_longest(pil_img: Image.Image, longest_side=1024):
        w, h = pil_img.size
        scale = float(longest_side) / max(w, h)
        new_w, new_h = int(round(w * scale)), int(round(h * scale))
        return pil_img.resize((new_w, new_h), Image.BICUBIC)

    def preprocess_image_fn(pil_img: Image.Image):
        img_resized = _resize_longest(pil_img, longest_side=longest)
        x = normalize(to_tensor(img_resized))  # [3,H,W]
        _, H, W = x.shape

        # pad to /32 for backbone
        H_pad = math.ceil(H / pad_divisor) * pad_divisor
        W_pad = math.ceil(W / pad_divisor) * pad_divisor
        pad_h, pad_w = H_pad - H, W_pad - W
        x = F.pad(x, (0, pad_w, 0, pad_h), value=0.0)  # [3,Hp,Wp]

        # mask: True on padded pixels
        mask = torch.zeros((H_pad, W_pad), dtype=torch.bool)
        if pad_h > 0:
            mask[H:, :] = True
        if pad_w > 0:
            mask[:, W:] = True

        return x.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)

    return preprocess_image_fn


# -----------------------------
# Tokenizer (BiomedVLP-CXR-BERT)
# -----------------------------
BIOMEDVLP_TOKENIZER_PATH = "VG/weights/BiomedVLP-CXR-BERT/"

_tokenizer = AutoTokenizer.from_pretrained(BIOMEDVLP_TOKENIZER_PATH)


def tokenize_with_offsets(prompt: str, device=DEVICE_DEFAULT):
    enc = _tokenizer(
        prompt,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
    )
    tokens = _tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    offsets = enc["offset_mapping"][0].tolist()
    return {
        "input_ids": enc["input_ids"].to(device),
        "attention_mask": enc["attention_mask"].to(device),
        "tokens": tokens,
        "offsets": offsets,
    }


def find_token_span_by_offsets(prompt: str, offsets, term: str):
    s = prompt.lower()
    t = term.lower()
    m = re.search(r'\b' + re.escape(t) + r'\b', s) or re.search(re.escape(t), s)
    if not m:
        return []
    a, b = m.start(), m.end()
    idxs = []
    for i, (u, v) in enumerate(offsets):
        if (
            u is None or v is None or
            u < 0 or v < 0 or
            (u == 0 and v == 0)
        ):
            continue
        if not (v <= a or u >= b):  # overlap with [a,b)
            idxs.append(i)
    return idxs


def model_span_indices_for_term(tokens, offsets, attn_T, term: str):
    # 1) HF indices by offsets
    raw_hf_idxs = find_token_span_by_offsets(
        "".join(t if t != "[PAD]" else " " for t in tokens),
        offsets,
        term
    )
    if not raw_hf_idxs:
        low = term.lower()
        raw_hf_idxs = [i for i, t in enumerate(tokens) if low in t.lower()]

    # 2) Map HF non-special → model positions 0..T-1
    non_special_hf = []
    for i, (tok_i, (u, v)) in enumerate(zip(tokens, offsets)):
        if tok_i in ("[CLS]", "[SEP]", "[PAD]"):
            continue
        if u is None or v is None or u < 0 or v < 0 or (u == 0 and v == 0):
            continue
        non_special_hf.append(i)

    non_special_hf = non_special_hf[:attn_T]
    hf2model = {hf_idx: j for j, hf_idx in enumerate(non_special_hf)}
    model_term_idxs = [hf2model[i] for i in raw_hf_idxs if i in hf2model]

    return torch.tensor(model_term_idxs, dtype=torch.long)


# -----------------------------
# Cross-attention recorder
# -----------------------------
class CrossAttnRecorder:
    def __init__(self, decoder_layers, attn_attr_name='ca_text'):
        self.attn_weights = []  # list of [B, heads, Q, T]
        self.handles = []
        self._register(decoder_layers, attn_attr_name)

    def _hook(self, module, input, output):
        if isinstance(output, tuple) and len(output) >= 2:
            attn_w = output[1]
        elif hasattr(module, 'attn_output_weights'):
            attn_w = module.attn_output_weights
        else:
            attn_w = None
        if attn_w is not None:
            self.attn_weights.append(attn_w.detach().to('cpu', dtype=torch.float32))

    def _wrap_forward(self, mha_module: nn.MultiheadAttention):
        orig_forward = mha_module.forward

        def wrapped_forward(*args, **kwargs):
            kwargs['need_weights'] = True
            kwargs['average_attn_weights'] = False
            return orig_forward(*args, **kwargs)

        return orig_forward, wrapped_forward

    def _register(self, decoder_layers, attn_attr_name):
        for layer in decoder_layers:
            attn_module = getattr(layer, attn_attr_name, None)
            if attn_module is None:
                continue
            if isinstance(attn_module, nn.MultiheadAttention):
                orig_fwd, wrapped = self._wrap_forward(attn_module)
                attn_module.forward = wrapped
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, orig_fwd))
            else:
                handle = attn_module.register_forward_hook(self._hook)
                self.handles.append((attn_module, handle, None))

    def close(self):
        for attn_module, handle, orig_fwd in self.handles:
            handle.remove()
            if (orig_fwd is not None) and isinstance(attn_module, nn.MultiheadAttention):
                attn_module.forward = orig_fwd


# -----------------------------
# Heatmap helpers
# -----------------------------
def boxes_to_heatmap(boxes_xyxy, weights, hw, score_scale=None, blur_ksize=51, blur_sigma=0):
    H, W = hw
    heat = np.zeros((H, W), dtype=np.float32)

    w = weights.detach().cpu().numpy()
    if score_scale is not None:
        s = score_scale.detach().cpu().numpy()
        w = w * s

    for i, box in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = map(int, box.tolist())
        x1 = max(0, min(W - 1, x1)); x2 = max(0, min(W - 1, x2))
        y1 = max(0, min(H - 1, y1)); y2 = max(0, min(H - 1, y2))
        if x2 <= x1 or y2 <= y1:
            continue
        heat[y1:y2, x1:x2] += float(w[i])

    if blur_ksize is not None and blur_ksize >= 3 and blur_ksize % 2 == 1:
        heat = cv2.GaussianBlur(heat, (blur_ksize, blur_ksize), blur_sigma)

    mx = heat.max()
    if mx > 1e-6:
        heat /= mx
    return heat


def overlay_heatmap(img_pil: Image.Image, heatmap, alpha=0.45, cmap=cv2.COLORMAP_JET):
    img = np.array(img_pil.convert("RGB"))
    H, W = img.shape[:2]
    h = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
    h_color = cv2.applyColorMap(h, cmap)[:, :, ::-1]
    blended = cv2.addWeighted(h_color, alpha, img, 1 - alpha, 0)
    return Image.fromarray(blended)


def load_image_keep_longest(path, longest=1024):
    img = Image.open(path).convert("RGB")
    w, h = img.size
    s = float(longest) / max(w, h)
    new_w, new_h = int(round(w * s)), int(round(h * s))
    return img.resize((new_w, new_h), Image.BICUBIC)


# -----------------------------
# Main helper: one call from API
# -----------------------------
@torch.no_grad()
def run_token_ca_visualization(
    cfg_path: str,
    ckpt_path: str,
    image_path: str,
    prompt: str,
    terms,
    out_dir: str,
    device: str = DEVICE_DEFAULT,
    score_thresh: float = 0.25,
    topk: int = 100,
    term_agg: str = "mean",  # "mean" | "max" | "sum"
    save_per_term: bool = True,
):
    """
    Returns:
        {
            "combined": <path_to_combined_overlay>,
            "per_term": { term: path_to_overlay, ... }
        }
    """
    if isinstance(terms, str):
        terms = [terms]
    terms = [t.strip() for t in terms if t and t.strip()]
    if not terms:
        raise ValueError("No terms provided for attention visualization.")

    device = device or DEVICE_DEFAULT
    model = load_model(cfg_path, ckpt_path).to(device).eval()
    preprocess_image_fn = preprocess_image_fn_factory(device=device, longest=1024, pad_divisor=32)

    img_pil = load_image_keep_longest(image_path, longest=1024)

    os.makedirs(out_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    combined_path = os.path.join(out_dir, f"{base_name}__attn_combined.png")

    # ---- hook cross-attn
    decoder_layers = model.transformer.decoder.layers
    recorder = CrossAttnRecorder(decoder_layers, attn_attr_name="ca_text")

    # preprocess → NestedTensor
    img_tensor, mask = preprocess_image_fn(img_pil)
    samples = NestedTensor(img_tensor, mask)

    outputs = model(samples, captions=[prompt])

    # decode boxes
    pred_logits = outputs["pred_logits"]
    pred_boxes = outputs["pred_boxes"]
    logits = pred_logits[0].sigmoid()
    scores, _ = logits.max(dim=1)

    keep = torch.nonzero(scores > score_thresh).squeeze(1)
    if keep.numel() == 0:
        keep = torch.argsort(scores, descending=True)[:min(topk, scores.numel())]
    else:
        keep = keep[:topk]

    W, H = img_pil.size
    boxes_cxcywh = pred_boxes[0][keep]
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    x1 = (cx - 0.5 * w) * W
    y1 = (cy - 0.5 * h) * H
    x2 = (cx + 0.5 * w) * W
    y2 = (cy + 0.5 * h) * H
    boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=-1)
    kept_scores = scores[keep]

    keep_cpu = keep.cpu()

    if len(recorder.attn_weights) == 0:
        recorder.close()
        raise RuntimeError("No attention weights captured. Check that 'ca_text' exists.")
    attn_qt_layers = []
    for w_att in recorder.attn_weights:
        w_att = w_att.squeeze(0).mean(0)  # [Q,T]
        attn_qt_layers.append(w_att)
    attn_qt = torch.stack(attn_qt_layers, 0).mean(0)  # [Q,T]
    recorder.close()

    # tokenize prompt
    tok = tokenize_with_offsets(prompt, device="cpu")
    tokens, offsets = tok["tokens"], tok["offsets"]
    T_text = attn_qt.shape[1]

    per_term_attn_kept = {}
    per_term_attn_full = {}

    for t in terms:
        model_idxs = model_span_indices_for_term(tokens, offsets, T_text, t)
        if model_idxs.numel() == 0:
            continue
        attn_per_query = attn_qt[:, model_idxs].mean(1)  # [Q]
        attn_kept = attn_per_query[keep_cpu]
        attn_kept = (attn_kept - attn_kept.min()) / (attn_kept.max() - attn_kept.min() + 1e-6)
        per_term_attn_kept[t] = attn_kept
        per_term_attn_full[t] = attn_per_query

    if not per_term_attn_kept:
        raise ValueError(f"None of the terms were found in the first T tokens: {terms}")

    # aggregate terms
    agg = None
    for t, v in per_term_attn_full.items():
        agg = v if agg is None else (
            agg + v if term_agg == "sum"
            else torch.maximum(agg, v) if term_agg == "max"
            else (agg + v)
        )
    if term_agg == "mean":
        agg = agg / float(len(per_term_attn_full))

    agg_kept = agg[keep_cpu]
    agg_kept = (agg_kept - agg_kept.min()) / (agg_kept.max() - agg_kept.min() + 1e-6)

    heat = boxes_to_heatmap(
        boxes_xyxy=boxes_xyxy,
        weights=agg_kept,
        hw=(H, W),
        score_scale=kept_scores,
        blur_ksize=61,
        blur_sigma=0,
    )
    overlay = overlay_heatmap(img_pil, heat, alpha=0.45)
    overlay.save(combined_path)

    per_term_paths = {}
    if save_per_term and len(per_term_attn_kept) > 1:
        for t, v in per_term_attn_kept.items():
            heat_t = boxes_to_heatmap(
                boxes_xyxy=boxes_xyxy,
                weights=v,
                hw=(H, W),
                score_scale=kept_scores,
                blur_ksize=61,
                blur_sigma=0,
            )
            ov_t = overlay_heatmap(img_pil, heat_t, alpha=0.45)
            term_tag = re.sub(r"[^a-zA-Z0-9]+", "_", t.lower())[:32]
            p = os.path.join(out_dir, f"{base_name}__{term_tag}.png")
            ov_t.save(p)
            per_term_paths[t] = p

    return {
        "combined": combined_path,
        "per_term": per_term_paths,
    }
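The heatmap construction in `boxes_to_heatmap` above is plain accumulation: each kept box contributes its (optionally score-scaled) attention weight uniformly over its rectangle, the sum is Gaussian-blurred, and the result is max-normalized to [0, 1]. A minimal self-contained sketch of that idea with synthetic boxes and weights (standalone NumPy/OpenCV, so it runs without the GroundingDINO checkpoint and tokenizer that importing this module requires):

    import cv2
    import numpy as np

    H, W = 256, 256
    boxes = [(20, 30, 120, 140), (100, 90, 200, 220)]   # synthetic (x1, y1, x2, y2) boxes
    weights = [0.9, 0.4]                                 # synthetic per-box attention weights

    heat = np.zeros((H, W), dtype=np.float32)
    for (x1, y1, x2, y2), w in zip(boxes, weights):
        heat[y1:y2, x1:x2] += w                          # uniform contribution inside each box

    heat = cv2.GaussianBlur(heat, (61, 61), 0)           # soften box edges
    heat /= max(float(heat.max()), 1e-6)                 # max-normalize to [0, 1]
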
vicca_api.py (CHANGED)

@@ -3,8 +3,11 @@ import os
 import time
 import cv2
 import pandas as pd
+import torch
 
 from weights_utils import ensure_all_vicca_weights, get_weight
+from vg_token_attention import run_token_ca_visualization
+
 
 # Make sure all heavy weights are present once per container
 ensure_all_vicca_weights()
@@ -14,6 +17,7 @@ from inference import (
     cal_shift,
     get_local_bbox,
     extract_tensor,
+    chexbert_pathology,
 )
 
 def run_vicca(
@@ -23,6 +27,7 @@ def run_vicca(
     text_threshold: float = 0.2,
     num_samples: int = 4,
     output_path: str = "CXRGen/test/samples/output/",
+    attn_terms=None,
 ):
     """
     Top-level VICCA API used by app.py / Gradio.
@@ -87,6 +92,28 @@ def run_vicca(
         score = ssim(roi_org, roi_gen)
         ssim_scores.append(score)
 
+    # Optional: attention visualization for terms (e.g. from CheXbert)
+    attn_paths = None
+    attn_terms = chexbert_pathology(text_prompt)
+    if attn_terms:
+        cfg_path = "VG/config/GroundingDINO_SwinT_OGC_2.py"
+        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
+        attn_out_dir = os.path.join(output_path, "attn_overlays")
+
+        attn_paths = run_token_ca_visualization(
+            cfg_path=cfg_path,
+            ckpt_path=vg_ckpt_path,
+            image_path=image_path,   # or max_sim_gen_path if you prefer the generated CXR
+            prompt=text_prompt,
+            terms=attn_terms,
+            out_dir=attn_out_dir,
+            device="cuda" if torch.cuda.is_available() else "cpu",
+            score_thresh=0.25,
+            topk=100,
+            term_agg="mean",
+            save_per_term=True,
+        )
+
     return {
         "boxes": boxes,
         "logits": logits,
@@ -95,4 +122,5 @@ def run_vicca(
         "shift_x": sx,
         "shift_y": sy,
         "best_generated_image_path": max_sim_gen_path,
+        "attention_overlays": attn_paths,
     }
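Downstream, app.py only relies on the shape of the new `attention_overlays` entry: a dict with a `combined` overlay path and a `per_term` mapping from pathology term to overlay path (or `None` when no terms are found and the visualization is skipped). A rough sketch of consuming it, with a made-up image path and prompt purely to illustrate the structure (it assumes vicca_api is importable and the model weights are present):

    # Hypothetical inputs; shown only to illustrate the returned structure.
    result = run_vicca(
        image_path="uploads/input.png",
        text_prompt="There is a small right pleural effusion.",
    )

    attn = result.get("attention_overlays") or {}
    combined_overlay = attn.get("combined")      # path to the combined heatmap PNG
    per_term = attn.get("per_term") or {}        # {term: overlay_path}
    gallery_items = [(path, term) for term, path in per_term.items()]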