""" Input: image and text Middle output: bbox (VG), Gen Image and similarity score (CXRGen), Shift_x&y (DETR) Output: Localization Score, Reliability Score python inference.py \ --image_path VG/38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg \ --text_prompt "Cardiomegaly with mild pulmonary vascular congestion." """ import sys, os # --------------------------------------------------------------------- # Make CheXbert's `src` folder importable (so `import utils` works) # --------------------------------------------------------------------- # BASE_DIR = os.path.dirname(__file__) # CHEXBERT_SRC = os.path.join(BASE_DIR, "CheXbert", "src") # if CHEXBERT_SRC not in sys.path: # sys.path.insert(0, CHEXBERT_SRC) # from label import label # now imports /app/CheXbert/src/label.py import pandas as pd import numpy as np import time import cv2 import argparse from ast import literal_eval # from nltk import tokenize # sys.path.append('/home/gholipos-admin/Desktop/Thesis/Training_Code/VICCA') from pathlib import Path import shutil from huggingface_hub import hf_hub_download from weights_utils import get_weight # def ensure_vicca_weights(): # """ # Download all VICCA weights from the vicca-weights repo into the paths # expected by the original code, with caching and safe subfolder handling. # """ # repo_id = "sayehghp/vicca-weights" # base = Path(__file__).parent # weight_files = [ # # CheXbert # "CheXbert/checkpoint/chexbert.pth", # # Uniformer # "CXRGen/annotator/ckpts/upernet_global_small.pth", # # Diffusion # "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth", # # Encoders # "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin", # "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin", # # Lung UNet # "CXRGen/LungDetection/models/unet-2v.pt", # "CXRGen/LungDetection/models/unet-6v.pt", # # DETR # "DETR/output/checkpoint.pth", # # VG weights # "VG/weights/checkpoint0399.pth", # "VG/weights/checkpoint0399_log4.pth", # "VG/weights/checkpoint_best_regular.pth", # ] # for rel_path in weight_files: # local_path = base / rel_path # local_path.parent.mkdir(parents=True, exist_ok=True) # if local_path.exists(): # continue # skip if already mirrored into repo tree # # Split repo path # if "/" in rel_path: # subfolder, filename = rel_path.rsplit("/", 1) # else: # subfolder, filename = None, rel_path # cached_path = hf_hub_download( # repo_id=repo_id, # filename=filename, # subfolder=subfolder if subfolder else None # ) # # Copy from HF cache → repo tree # shutil.copy2(cached_path, local_path) # # Run once at import time so all weights are present before anything loads them # ensure_vicca_weights() # ---- SHIM FOR basicsr / torchvision ---- import types from torchvision.transforms import functional as F # Create a fake module torchvision.transforms.functional_tensor # and expose rgb_to_grayscale from torchvision.transforms.functional mod = types.ModuleType("torchvision.transforms.functional_tensor") mod.rgb_to_grayscale = F.rgb_to_grayscale sys.modules["torchvision.transforms.functional_tensor"] = mod # ---- END SHIM ---- from CXRGen import sample_generation from DETR import svc from DETR.arguments import get_args_parser as get_detr_args_parser from VG import localization from ssim import ssim import torch from CheXbert.src.label import label def get_args_parser(): parser = argparse.ArgumentParser('Set the Input', add_help=True) parser.add_argument('--weight_path_gencxr', type=str, default="CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth", help="Path to the CXR generation trained model") 
    parser.add_argument('--weight_path_vg', type=str, default="VG/weights/checkpoint0399_log4.pth", help="Path to the Visual Grounding trained model")
    parser.add_argument('--image_path', type=str, required=True, help="Path to the input image file.")
    parser.add_argument('--text_prompt', type=str, required=True, help="Text prompt describing pathology.")
    parser.add_argument('--box_threshold', default=0.2, type=float, help="Box threshold for VG")
    parser.add_argument('--text_threshold', default=0.2, type=float, help="Text threshold for VG")
    parser.add_argument('--num_samples', type=int, default=4, help="Number of generated image samples.")
    parser.add_argument('--output_path', type=str, default="CXRGen/test/samples/output/", help="Path to save generated files.")
    return parser


import re


def simple_sentence_split(text: str):
    """
    Very lightweight sentence splitter good enough for radiology reports.
    Splits on '.', ';', and newlines, then strips whitespace.
    """
    parts = re.split(r"[.\n;]+", text)
    return [p.strip() for p in parts if p.strip()]


path_list = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema',
             'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
             'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# Cache CheXbert weights once at import time
CHEXBERT_WEIGHTS = get_weight("CheXbert/checkpoint/chexbert.pth")

# def chexbert_pathology(text):
#     sentences = list(set(tokenize.sent_tokenize(text)))
#     path_dict = []
#     for sentence in sentences:
#         sentence = sentence.replace('\n',' ')
#         sentence = sentence.replace('\s+',' ')
#         chexbert_weight_path = get_weight("CheXbert/checkpoint/chexbert.pth")
#         # pathology = np.array(label("CheXbert/checkpoint/chexbert.pth", sentence)).T[0]
#         pathology = np.array(label(chexbert_weight_path, sentence)).T[0]
#         if pathology[-1]==1 or len(list(set(pathology)))==1 or not any(e==1 for e in pathology):
#             pass
#         else:
#             indice = [i for i, e in enumerate(pathology) if e==1]
#             for ind in indice:
#                 path_dict.append(path_list[ind])
#     return path_dict


def chexbert_pathology(text: str):
    """
    Run CheXbert on the text and return a list of *positive* pathology labels,
    deduplicated.
    """
    # If NLTK punkt ever becomes a problem on Spaces, replace this with a simple split.
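    # NOTE: this build already uses simple_sentence_split (defined above) rather than NLTK.
    # label() is assumed to return one prediction per entry in path_list (14 classes,
    # "No Finding" last), with 1 marking a positive mention; see the filtering below.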
    # sentences = list(set(tokenize.sent_tokenize(text)))
    # sentences = [s.strip() for s in text.split(".") if s.strip()]
    sentences = list(set(simple_sentence_split(text)))

    path_terms = set()
    for sentence in sentences:
        sentence = sentence.replace("\n", " ")
        sentence = re.sub(r"\s+", " ", sentence)

        # Run CheXbert
        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]

        # Skip if: "No Finding" active, or all labels same, or no positives
        if pathology[-1] == 1 or len(set(pathology)) == 1 or not any(e == 1 for e in pathology):
            continue

        # Collect positive indices
        indices = [i for i, e in enumerate(pathology) if e == 1]
        for ind in indices:
            path_terms.add(path_list[ind])

    return sorted(path_terms)


def extract_tensor(value):
    cleaned_value = value.replace('tensor(', '').replace(')', '')
    return literal_eval(cleaned_value)


def gen_cxr(weight_path, image_path, text_prompt, num_samples, output_path, device: str = "cpu"):
    parser = sample_generation.get_args_parser()
    args = parser.parse_args([])
    # args.weight_path = weight_path
    args.image_path = image_path
    args.text_prompt = text_prompt
    args.num_samples = num_samples
    args.output_path = output_path
    args.weight_path = get_weight(weight_path)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    args.device = device
    sample_generation.main(args)


def cal_shift(img_org_path, img_gen_path):
    parser = get_detr_args_parser()
    args = parser.parse_args([])
    args.read_checkpoint = get_weight("DETR/output/checkpoint.pth")
    args.img_org = img_org_path
    args.img_gen = img_gen_path
    shift_x, shift_y = svc.main(args)
    return shift_x, shift_y


def get_local_bbox(weight_path, image_path, text_prompt, box_threshold, text_threshold):
    parser = localization.get_args_parser()
    args = parser.parse_args([])
    # vg_ckpt_main = get_weight("VG/weights/checkpoint0399.pth")
    # vg_ckpt_best = get_weight("VG/weights/checkpoint_best_regular.pth")
    # vg_ckpt_log4 = get_weight("VG/weights/checkpoint0399_log4.pth")
    # args.weight_path = weight_path
    args.weight_path = get_weight(weight_path)
    args.image_path = image_path
    args.text_prompt = text_prompt
    args.box_threshold = box_threshold
    args.text_threshold = text_threshold
    bbox, logits, phrases = localization.main(args)
    return bbox, logits, phrases


if __name__ == "__main__":
    args = get_args_parser().parse_args()

    gen_cxr(args.weight_path_gencxr, args.image_path, args.text_prompt, args.num_samples, args.output_path)
    time.sleep(4)  # ensure outputs are written

    df = pd.read_csv(args.output_path + "info_path_similarity.csv")
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    max_sim_gen_path = df["gen_sample_path"][max_sim_index]

    sx, sy = cal_shift(args.image_path, max_sim_gen_path)

    boxes, logits, phrases = get_local_bbox(
        args.weight_path_vg, args.image_path, args.text_prompt,
        args.box_threshold, args.text_threshold
    )
    print("Boxes:", boxes)
    print("Phrases:", phrases)

    image_org_cv = cv2.imread(args.image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)

    ssim_scores = []
    for bbox in boxes:
        x1, y1, x2, y2 = bbox
        bbox1 = [x1, y1, x2 - x1, y2 - y1]
        bbox2 = [x1 + sx, y1 + sy, x2 - x1, y2 - y1]
        bx1, by1, bw1, bh1 = [int(val) for val in bbox1]
        bx2, by2, bw2, bh2 = [int(val) for val in bbox2]
        roi_org = image_org_cv[by1:by1 + bh1, bx1:bx1 + bw1]
        roi_gen = image_gen_cv[by2:by2 + bh2, bx2:bx2 + bw2]
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            score = ssim(roi_org, roi_gen)
            ssim_scores.append(score)

    if ssim_scores:
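        # Each SSIM score compares the grounded region of the original image with the
        # shift-compensated region of the most similar generated sample.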
print("SSIM scores per box:", ssim_scores) print("Localization Detection Scores per bbox:", boxes, logits) # print("Average SSIM (Localization Score):", sum(ssim_scores) / len(ssim_scores)) else: print("No valid SSIM scores (e.g., mismatched shapes or empty ROIs).")