AustingDong
committed on
Commit
·
6d117d1
1
Parent(s):
7e57874
Added a custom loss (not useful; left commented out)
Browse files
- demo/visualization.py +11 -0
- evaluate/evaluate.py +7 -9
- evaluate/questions.py +12 -12
demo/visualization.py
CHANGED
|
@@ -406,6 +406,16 @@ class VisualizationChartGemma(Visualization):
|
|
| 406 |
super().__init__(model, register=True)
|
| 407 |
self._modify_layers()
|
| 408 |
self._register_hooks_activations()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
|
| 411 |
outputs_raw = self.model(**inputs, output_hidden_states=True)
|
|
@@ -421,6 +431,7 @@ class VisualizationChartGemma(Visualization):
|
|
| 421 |
print("logits shape:", outputs_raw.logits.shape)
|
| 422 |
if target_token_idx == -1:
|
| 423 |
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
|
|
|
| 424 |
else:
|
| 425 |
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
|
| 426 |
loss.backward()
|
|
|
|
| 406 |
super().__init__(model, register=True)
|
| 407 |
self._modify_layers()
|
| 408 |
self._register_hooks_activations()
|
| 409 |
+
|
| 410 |
+
# def custom_loss(self, start_idx, input_ids, logits):
|
| 411 |
+
# Q = logits.shape[1]
|
| 412 |
+
# loss = 0
|
| 413 |
+
# q = 0
|
| 414 |
+
# while start_idx + q < Q - 1:
|
| 415 |
+
# loss += F.cross_entropy(logits[0, start_idx + q], input_ids[0, start_idx + q + 1])
|
| 416 |
+
# q += 1
|
| 417 |
+
# return loss
|
| 418 |
+
|
| 419 |
|
| 420 |
def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
|
| 421 |
outputs_raw = self.model(**inputs, output_hidden_states=True)
|
|
|
|
| 431 |
print("logits shape:", outputs_raw.logits.shape)
|
| 432 |
if target_token_idx == -1:
|
| 433 |
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
| 434 |
+
# loss = self.custom_loss(start_idx, inputs['input_ids'], outputs_raw.logits)
|
| 435 |
else:
|
| 436 |
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
|
| 437 |
loss.backward()
|
evaluate/evaluate.py
CHANGED
|
@@ -7,9 +7,9 @@ from openai import OpenAI
|
|
| 7 |
from demo.model_utils import *
|
| 8 |
from evaluate.questions import questions
|
| 9 |
|
| 10 |
-
def set_seed(model_seed =
|
| 11 |
torch.manual_seed(model_seed)
|
| 12 |
-
np.random.seed(model_seed)
|
| 13 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
| 14 |
|
| 15 |
def clean():
|
|
@@ -52,7 +52,7 @@ def evaluate(model_type, num_eval = 10):
|
|
| 52 |
client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
|
| 53 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
| 54 |
|
| 55 |
-
for question in questions:
|
| 56 |
chart_type = question[0]
|
| 57 |
q = question[1]
|
| 58 |
img_path = question[2]
|
|
@@ -104,8 +104,8 @@ def evaluate(model_type, num_eval = 10):
|
|
| 104 |
|
| 105 |
else:
|
| 106 |
prepare_inputs = model_utils.prepare_inputs(q, image)
|
| 107 |
-
temperature = 0.
|
| 108 |
-
top_p = 0.
|
| 109 |
|
| 110 |
if model_type.split('-')[0] == "Janus":
|
| 111 |
inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
|
|
@@ -120,7 +120,7 @@ def evaluate(model_type, num_eval = 10):
|
|
| 120 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
| 121 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
| 122 |
|
| 123 |
-
with open(f"{FILES_ROOT}/{chart_type}.txt", "w") as f:
|
| 124 |
f.write(answer)
|
| 125 |
f.close()
|
| 126 |
|
|
@@ -129,8 +129,6 @@ def evaluate(model_type, num_eval = 10):
|
|
| 129 |
if __name__ == '__main__':
|
| 130 |
|
| 131 |
# models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
|
| 132 |
-
|
| 133 |
-
# models = ["Janus-Pro-7B", "LLaVA-1.5-7B"]
|
| 134 |
-
models = ["GPT-4o", "Gemini-2.0-flash"]
|
| 135 |
for model_type in models:
|
| 136 |
evaluate(model_type=model_type, num_eval=10)
|
|
|
|
| 7 |
from demo.model_utils import *
|
| 8 |
from evaluate.questions import questions
|
| 9 |
|
| 10 |
+
def set_seed(model_seed = 70):
|
| 11 |
torch.manual_seed(model_seed)
|
| 12 |
+
# np.random.seed(model_seed)
|
| 13 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
| 14 |
|
| 15 |
def clean():
|
|
|
|
| 52 |
client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
|
| 53 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
| 54 |
|
| 55 |
+
for question_idx, question in enumerate(questions):
|
| 56 |
chart_type = question[0]
|
| 57 |
q = question[1]
|
| 58 |
img_path = question[2]
|
|
|
|
| 104 |
|
| 105 |
else:
|
| 106 |
prepare_inputs = model_utils.prepare_inputs(q, image)
|
| 107 |
+
temperature = 0.1
|
| 108 |
+
top_p = 0.95
|
| 109 |
|
| 110 |
if model_type.split('-')[0] == "Janus":
|
| 111 |
inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
|
|
|
|
| 120 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
| 121 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
| 122 |
|
| 123 |
+
with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
|
| 124 |
f.write(answer)
|
| 125 |
f.close()
|
| 126 |
|
|
|
|
| 129 |
if __name__ == '__main__':
|
| 130 |
|
| 131 |
# models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
|
| 132 |
+
models = ["Janus-Pro-7B"]
|
|
|
|
|
|
|
| 133 |
for model_type in models:
|
| 134 |
evaluate(model_type=model_type, num_eval=10)
|
evaluate/questions.py
CHANGED
|
@@ -2,72 +2,72 @@ questions=[
|
|
| 2 |
[
|
| 3 |
"LineChart",
|
| 4 |
"What was the price of a barrel of oil in February 2020?",
|
| 5 |
-
"images/LineChart.png"
|
| 6 |
],
|
| 7 |
|
| 8 |
[
|
| 9 |
"BarChart",
|
| 10 |
"What is the average internet speed in Japan?",
|
| 11 |
-
"images/BarChart.png"
|
| 12 |
],
|
| 13 |
|
| 14 |
[
|
| 15 |
"StackedBar",
|
| 16 |
"What is the cost of peanuts in Seoul?",
|
| 17 |
-
"images/StackedBar.png"
|
| 18 |
],
|
| 19 |
|
| 20 |
[
|
| 21 |
"100%StackedBar",
|
| 22 |
"Which country has the lowest proportion of Gold medals?",
|
| 23 |
-
"images/Stacked100.png"
|
| 24 |
],
|
| 25 |
|
| 26 |
[
|
| 27 |
"PieChart",
|
| 28 |
"What is the approximate global smartphone market share of Samsung?",
|
| 29 |
-
"images/PieChart.png"
|
| 30 |
],
|
| 31 |
|
| 32 |
[
|
| 33 |
"Histogram",
|
| 34 |
"What distance have customers traveled in the taxi the most?",
|
| 35 |
-
"images/Histogram.png"
|
| 36 |
],
|
| 37 |
|
| 38 |
[
|
| 39 |
"Scatterplot",
|
| 40 |
"True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
|
| 41 |
-
"images/Scatterplot.png"
|
| 42 |
],
|
| 43 |
|
| 44 |
[
|
| 45 |
"AreaChart",
|
| 46 |
"What was the average price of pount of coffee beans in October 2019?",
|
| 47 |
-
"images/AreaChart.png"
|
| 48 |
],
|
| 49 |
|
| 50 |
[
|
| 51 |
"StackedArea",
|
| 52 |
"What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
|
| 53 |
-
"images/StackedArea.png"
|
| 54 |
],
|
| 55 |
|
| 56 |
[
|
| 57 |
"BubbleChart",
|
| 58 |
"Which city's metro system has the largest number of stations?",
|
| 59 |
-
"images/BubbleChart.png"
|
| 60 |
],
|
| 61 |
|
| 62 |
[
|
| 63 |
"Choropleth",
|
| 64 |
"True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
|
| 65 |
-
"images/Choropleth_New.png"
|
| 66 |
],
|
| 67 |
|
| 68 |
[
|
| 69 |
"TreeMap",
|
| 70 |
"True/False: eBay is nested in the Software category.",
|
| 71 |
-
"images/TreeMap.png"
|
| 72 |
]
|
| 73 |
]
|
|
|
|
| 2 |
[
|
| 3 |
"LineChart",
|
| 4 |
"What was the price of a barrel of oil in February 2020?",
|
| 5 |
+
"images/mini-VLAT/LineChart.png"
|
| 6 |
],
|
| 7 |
|
| 8 |
[
|
| 9 |
"BarChart",
|
| 10 |
"What is the average internet speed in Japan?",
|
| 11 |
+
"images/mini-VLAT/BarChart.png"
|
| 12 |
],
|
| 13 |
|
| 14 |
[
|
| 15 |
"StackedBar",
|
| 16 |
"What is the cost of peanuts in Seoul?",
|
| 17 |
+
"images/mini-VLAT/StackedBar.png"
|
| 18 |
],
|
| 19 |
|
| 20 |
[
|
| 21 |
"100%StackedBar",
|
| 22 |
"Which country has the lowest proportion of Gold medals?",
|
| 23 |
+
"images/mini-VLAT/Stacked100.png"
|
| 24 |
],
|
| 25 |
|
| 26 |
[
|
| 27 |
"PieChart",
|
| 28 |
"What is the approximate global smartphone market share of Samsung?",
|
| 29 |
+
"images/mini-VLAT/PieChart.png"
|
| 30 |
],
|
| 31 |
|
| 32 |
[
|
| 33 |
"Histogram",
|
| 34 |
"What distance have customers traveled in the taxi the most?",
|
| 35 |
+
"images/mini-VLAT/Histogram.png"
|
| 36 |
],
|
| 37 |
|
| 38 |
[
|
| 39 |
"Scatterplot",
|
| 40 |
"True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
|
| 41 |
+
"images/mini-VLAT/Scatterplot.png"
|
| 42 |
],
|
| 43 |
|
| 44 |
[
|
| 45 |
"AreaChart",
|
| 46 |
"What was the average price of pount of coffee beans in October 2019?",
|
| 47 |
+
"images/mini-VLAT/AreaChart.png"
|
| 48 |
],
|
| 49 |
|
| 50 |
[
|
| 51 |
"StackedArea",
|
| 52 |
"What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
|
| 53 |
+
"images/mini-VLAT/StackedArea.png"
|
| 54 |
],
|
| 55 |
|
| 56 |
[
|
| 57 |
"BubbleChart",
|
| 58 |
"Which city's metro system has the largest number of stations?",
|
| 59 |
+
"images/mini-VLAT/BubbleChart.png"
|
| 60 |
],
|
| 61 |
|
| 62 |
[
|
| 63 |
"Choropleth",
|
| 64 |
"True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
|
| 65 |
+
"images/mini-VLAT/Choropleth_New.png"
|
| 66 |
],
|
| 67 |
|
| 68 |
[
|
| 69 |
"TreeMap",
|
| 70 |
"True/False: eBay is nested in the Software category.",
|
| 71 |
+
"images/mini-VLAT/TreeMap.png"
|
| 72 |
]
|
| 73 |
]
|