sayehghp committed on
Commit
27e3844
·
1 Parent(s): 3600cfd

Visualization

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. requirements.txt +1 -0
  3. vicca_api.py +8 -1
app.py CHANGED
@@ -19,14 +19,14 @@ def vicca_interface(image, text_prompt, box_threshold=0.2, text_threshold=0.2, n
19
  )
20
 
21
  best_gen = result.get("best_generated_image_path")
22
-
23
  attn = result.get("attention_overlays") or {}
24
  combined = attn.get("combined")
25
  per_term_dict = attn.get("per_term") or {}
26
 
27
  gallery_items = [(p, term) for term, p in per_term_dict.items()]
28
 
29
- return best_gen, combined, gallery_items, result
30
 
31
  demo = gr.Interface(
32
  fn=vicca_interface,
@@ -39,6 +39,7 @@ demo = gr.Interface(
39
  ],
40
  outputs=[
41
  gr.Image(label="Best generated CXR"),
 
42
  gr.Image(label="Combined attention heatmap"),
43
  # gr.Gallery(label="Per-term overlays").style(grid=[3], height=400),
44
  gr.Gallery(
 
19
  )
20
 
21
  best_gen = result.get("best_generated_image_path")
22
+ VG_path = result.get("VG_annotated_image_path")
23
  attn = result.get("attention_overlays") or {}
24
  combined = attn.get("combined")
25
  per_term_dict = attn.get("per_term") or {}
26
 
27
  gallery_items = [(p, term) for term, p in per_term_dict.items()]
28
 
29
+ return best_gen, VG_path, combined, gallery_items, result
30
 
31
  demo = gr.Interface(
32
  fn=vicca_interface,
 
39
  ],
40
  outputs=[
41
  gr.Image(label="Best generated CXR"),
42
+ gr.Image(label="VG annotated image"),
43
  gr.Image(label="Combined attention heatmap"),
44
  # gr.Gallery(label="Per-term overlays").style(grid=[3], height=400),
45
  gr.Gallery(
requirements.txt CHANGED
@@ -49,6 +49,7 @@ scikit-learn
49
  scikit-image
50
  tqdm
51
  statsmodels
 
52
 
53
  # Formatting / style
54
  yapf
 
49
  scikit-image
50
  tqdm
51
  statsmodels
52
+ supervision
53
 
54
  # Formatting / style
55
  yapf
vicca_api.py CHANGED
@@ -4,10 +4,11 @@ import time
4
  import cv2
5
  import pandas as pd
6
  import torch
 
7
 
8
  from weights_utils import ensure_all_vicca_weights, get_weight
9
  from vg_token_attention import run_token_ca_visualization
10
-
11
 
12
  # Make sure all heavy weights are present once per container
13
  ensure_all_vicca_weights()
@@ -69,6 +70,11 @@ def run_vicca(
69
  box_threshold,
70
  text_threshold,
71
  )
 
 
 
 
 
72
 
73
  # 5) SSIM per bbox
74
  image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
@@ -126,4 +132,5 @@ def run_vicca(
126
  "shift_y": sy,
127
  "best_generated_image_path": max_sim_gen_path,
128
  "attention_overlays": attn_paths,
 
129
  }
 
4
  import cv2
5
  import pandas as pd
6
  import torch
7
+ import supervision as sv
8
 
9
  from weights_utils import ensure_all_vicca_weights, get_weight
10
  from vg_token_attention import run_token_ca_visualization
11
+ from VG.groundingdino.util.inference import annotate
12
 
13
  # Make sure all heavy weights are present once per container
14
  ensure_all_vicca_weights()
 
70
  box_threshold,
71
  text_threshold,
72
  )
73
+ annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=2, text_thickness=1)
74
+
75
+ annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases, bbox_annot=annotate_dict)
76
+ VG_path = os.path.join(output_path, "VG_annotations.jpg")
77
+ cv2.imwrite(VG_path, annotated_frame)
78
 
79
  # 5) SSIM per bbox
80
  image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
 
132
  "shift_y": sy,
133
  "best_generated_image_path": max_sim_gen_path,
134
  "attention_overlays": attn_paths,
135
+ "VG_annotated_image_path": VG_path,
136
  }