doncamilom committed on
Commit 142faac · 1 Parent(s): 84d3480

update with our app

Files changed (3)
  1. Dockerfile +1 -1
  2. src/app.py +406 -0
  3. src/streamlit_app.py +0 -40
Dockerfile CHANGED
@@ -18,4 +18,4 @@ EXPOSE 8501
 
 HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
 
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
src/app.py ADDED
@@ -0,0 +1,406 @@
+"""
+Streamlit app for human evaluation of model outputs.
+
+Allows users to select two models, compare their responses to the same inputs,
+and record preferences for subsequent analysis.
+"""
+import os
+import json
+import csv
+from datetime import datetime
+
+import streamlit as st
+import pandas as pd
+
+st.set_page_config(page_title="Model Comparison Evaluation", layout="wide")
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+DATA_DIR = os.path.join(SCRIPT_DIR, "for_experiments_prediction")
+
+
+@st.cache_data
+def load_models(data_dir):
+    """
+    Discover prediction JSON files named 'predicted_vs_gt.json' and load flattened records for each model.
+    Returns a dict mapping model name to dict of {id: record}.
+    """
+    model_paths = {}
+    for root, _, files in os.walk(data_dir):
+        for fname in files:
+            if fname == 'predicted_vs_gt.json':
+                path = os.path.join(root, fname)
+                rel = os.path.relpath(root, data_dir)
+                model_paths[rel] = path
+    models = {}
+    for model_name, path in sorted(model_paths.items()):
+        with open(path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        records = {}
+        for section in data.values():
+            if isinstance(section, dict):
+                for sub in section.values():
+                    for rec in sub:
+                        records[rec['id']] = rec
+            elif isinstance(section, list):
+                for rec in section:
+                    records[rec['id']] = rec
+        models[model_name] = records
+    return models
+
+
+def append_feedback(feedback_file, header, row):
+    """
+    Append a single feedback row to TSV, writing header if file does not exist.
+    """
+    write_header = not os.path.exists(feedback_file)
+    with open(feedback_file, 'a', newline='', encoding='utf-8') as f:
+        writer = csv.writer(f, delimiter='\t', quoting=csv.QUOTE_ALL)
+        if write_header:
+            writer.writerow(header)
+        writer.writerow(row)
+
+@st.cache_data
+def load_eval_tables(data_dir):
+    """
+    Discover evaluation_table.parquet files under each model directory and load each into a pandas DataFrame.
+    Returns a dict mapping model name to its evaluation DataFrame.
+    """
+    tables = {}
+    for root, _, files in os.walk(data_dir):
+        if 'evaluation_table.parquet' in files:
+            path = os.path.join(root, 'evaluation_table.parquet')
+            rel = os.path.relpath(root, data_dir)
+            tables[rel] = pd.read_parquet(path)
+    return tables
+
+
+def main():
+    st.title("Model Comparison Evaluation")
+
+    print(DATA_DIR)
+    models = load_models(DATA_DIR)
+    eval_tables = load_eval_tables(DATA_DIR)
+    all_cols = set()
+    for df in eval_tables.values():
+        all_cols.update(df.columns)
+    key_columns = {'gt_sac_id', 'gt_title'}
+    metric_columns = sorted(all_cols - key_columns)
+    fixed_metrics = [
+        'chemicals_accuracy',
+        'chemicals_f1_score',
+        'chemicals_precision',
+        'chemicals_recall',
+        'metal_accuracy',
+        'metal_f1_score',
+        'metal_precision',
+        'metal_recall',
+        'procedure_procedure_completeness_score',
+        'procedure_procedure_order_score',
+        'procedure_procedure_accuracy_score',
+        'support_accuracy',
+        'support_f1_score',
+        'support_precision',
+        'support_recall',
+    ]
+    other_metrics = sorted([c for c in metric_columns if c not in fixed_metrics and not c.startswith('gt_')])
+    model_names = list(models.keys())
+
+    st.sidebar.header("Configuration")
+    def reset_index():
+        st.session_state.idx = 0
+        # Reset feedback file when models change
+        feedback_path = os.path.join(SCRIPT_DIR, 'feedback.tsv')
+        if os.path.exists(feedback_path):
+            os.remove(feedback_path)
+
+    selected = st.sidebar.multiselect(
+        "Select exactly two models to compare",
+        options=model_names,
+        key='models',
+        help="Choose two model variants for side-by-side comparison",
+        on_change=reset_index
+    )
+    if len(selected) != 2:
+        st.sidebar.info("Please select exactly two models.")
+        st.stop()
+
+    # Download button for feedback TSV
+    feedback_path = os.path.join(SCRIPT_DIR, 'feedback.tsv')
+    if os.path.exists(feedback_path):
+        with open(feedback_path, 'r', encoding='utf-8') as f:
+            tsv_data = f.read()
+        st.sidebar.download_button(
+            label="Download Feedback TSV",
+            data=tsv_data,
+            file_name="feedback.tsv",
+            mime="text/tab-separated-values"
+        )
+
+    m1, m2 = selected
+    recs1 = models[m1]
+    recs2 = models[m2]
+    common_ids = sorted(set(recs1.keys()) & set(recs2.keys()))
+    if not common_ids:
+        st.error("No common records between the selected models.")
+        st.stop()
+
+    if 'idx' not in st.session_state:
+        st.session_state.idx = 0
+    if 'feedback_saved' not in st.session_state:
+        st.session_state.feedback_saved = False
+
+    # Initialize fresh feedback file for new session
+    if 'session_initialized' not in st.session_state:
+        feedback_path = os.path.join(SCRIPT_DIR, 'feedback.tsv')
+        if os.path.exists(feedback_path):
+            os.remove(feedback_path)
+        st.session_state.session_initialized = True
+
+    total = len(common_ids)
+    idx = st.session_state.idx
+    if idx < 0:
+        idx = 0
+    if idx >= total:
+        st.write("### Evaluation complete! Thank you for your feedback.")
+        st.stop()
+
+    current_id = common_ids[idx]
+    rec1 = recs1[current_id]
+    rec2 = recs2[current_id]
+
+    st.markdown(f"**Record {idx+1}/{total} — ID: {current_id}**")
+    st.markdown("---")
+    st.subheader("Input Prompt")
+    st.code(rec1.get('input', ''), language='')
+
+    st.subheader("Model Responses and Ground Truth")
+    col1, col2, col3 = st.columns(3)
+    with col1:
+        st.markdown(f"**{m1}**")
+        st.text_area("", rec1.get('predicted', ''), height=600, key=f"resp1_{idx}")
+    with col2:
+        st.markdown(f"**{m2}**")
+        st.text_area("", rec2.get('predicted', ''), height=600, key=f"resp2_{idx}")
+    with col3:
+        st.markdown("**Ground Truth**")
+        st.text_area("", rec1.get('ground_truth', ''), height=600, key=f"gt_{idx}")
+
+    fcol1, fcol2, fcol3 = st.columns(3)
+    with fcol1:
+        df1 = eval_tables.get(m1)
+        if df1 is not None:
+            if 'gt_sac_id' in df1.columns:
+                key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
+                key_col = 'gt_sac_id'
+            elif 'gt_title' in df1.columns:
+                key_val = rec1.get('gt_title', rec1.get('title'))
+                key_col = 'gt_title'
+            else:
+                key_col = key_val = None
+            if key_col and key_val is not None:
+                row = df1[df1[key_col] == key_val]
+                if not row.empty:
+                    fm_df = row[fixed_metrics].T
+                    fm_df.columns = ['value']
+                    st.table(fm_df)
+                else:
+                    st.info("No fixed metrics for this record.")
+        else:
+            st.info("No evaluation table available for this model.")
+
+    with fcol2:
+        df2 = eval_tables.get(m2)
+        if df2 is not None:
+            if 'gt_sac_id' in df2.columns:
+                key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
+                key_col = 'gt_sac_id'
+            elif 'gt_title' in df2.columns:
+                key_val = rec1.get('gt_title', rec1.get('title'))
+                key_col = 'gt_title'
+            else:
+                key_col = key_val = None
+            if key_col and key_val is not None:
+                row = df2[df2[key_col] == key_val]
+                if not row.empty:
+                    fm_df = row[fixed_metrics].T
+                    fm_df.columns = ['value']
+                    st.table(fm_df)
+                else:
+                    st.info("No fixed metrics for this record.")
+        else:
+            st.info("No evaluation table available for this model.")
+
+    if other_metrics:
+        selected_metric = st.selectbox(
+            "Select additional metric to display",
+            options=other_metrics,
+            key=f"metric_sel_{idx}"
+        )
+    else:
+        selected_metric = None
+
+    if selected_metric:
+        mcol1, mcol2, mcol3 = st.columns(3)
+        with mcol1:
+            df1 = eval_tables.get(m1)
+            if df1 is not None and selected_metric in df1.columns:
+                if 'gt_sac_id' in df1.columns:
+                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
+                    key_col = 'gt_sac_id'
+                elif 'gt_title' in df1.columns:
+                    key_val = rec1.get('gt_title', rec1.get('title'))
+                    key_col = 'gt_title'
+                else:
+                    key_col = key_val = None
+                if key_col and key_val is not None:
+                    row = df1[df1[key_col] == key_val]
+                    if not row.empty:
+                        value = row[selected_metric].iloc[0]
+                        try:
+                            # Try to parse as JSON first
+                            parsed_json = json.loads(str(value))
+                            formatted_json = json.dumps(parsed_json, indent=2)
+                            st.markdown(f"**{selected_metric}:**")
+                            st.code(formatted_json, language='json')
+                        except json.JSONDecodeError:
+                            try:
+                                # If JSON fails, try to evaluate as Python literal (handles single quotes)
+                                import ast
+                                parsed_json = ast.literal_eval(str(value))
+                                formatted_json = json.dumps(parsed_json, indent=2)
+                                st.markdown(f"**{selected_metric}:**")
+                                st.code(formatted_json, language='json')
+                            except (ValueError, SyntaxError):
+                                # If all parsing fails, show as raw text
+                                st.markdown(f"**{selected_metric}:** {value}")
+                        except (TypeError, ValueError):
+                            st.markdown(f"**{selected_metric}:** {value}")
+                    else:
+                        st.markdown(f"**{selected_metric}:** N/A")
+            else:
+                st.markdown(f"**{selected_metric}:** N/A")
+
+        with mcol2:
+            df2 = eval_tables.get(m2)
+            if df2 is not None and selected_metric in df2.columns:
+                if 'gt_sac_id' in df2.columns:
+                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
+                    key_col = 'gt_sac_id'
+                elif 'gt_title' in df2.columns:
+                    key_val = rec1.get('gt_title', rec1.get('title'))
+                    key_col = 'gt_title'
+                else:
+                    key_col = key_val = None
+                if key_col and key_val is not None:
+                    row = df2[df2[key_col] == key_val]
+                    if not row.empty:
+                        value = row[selected_metric].iloc[0]
+                        try:
+                            # Try to parse as JSON first
+                            parsed_json = json.loads(str(value))
+                            formatted_json = json.dumps(parsed_json, indent=2)
+                            st.markdown(f"**{selected_metric}:**")
+                            st.code(formatted_json, language='json')
+                        except json.JSONDecodeError:
+                            try:
+                                # If JSON fails, try to evaluate as Python literal (handles single quotes)
+                                import ast
+                                parsed_json = ast.literal_eval(str(value))
+                                formatted_json = json.dumps(parsed_json, indent=2)
+                                st.markdown(f"**{selected_metric}:**")
+                                st.code(formatted_json, language='json')
+                            except (ValueError, SyntaxError):
+                                # If all parsing fails, show as raw text
+                                st.markdown(f"**{selected_metric}:** {value}")
+                        except (TypeError, ValueError):
+                            st.markdown(f"**{selected_metric}:** {value}")
+                    else:
+                        st.markdown(f"**{selected_metric}:** N/A")
+            else:
+                st.markdown(f"**{selected_metric}:** N/A")
+
+        with mcol3:
+            st.markdown("**Ground Truth Metrics**")
+            df_for_gt = eval_tables.get(m1)
+            if df_for_gt is None:
+                df_for_gt = eval_tables.get(m2)
+            if df_for_gt is not None:
+                if 'gt_sac_id' in df_for_gt.columns:
+                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
+                    key_col = 'gt_sac_id'
+                elif 'gt_title' in df_for_gt.columns:
+                    key_val = rec1.get('gt_title', rec1.get('title'))
+                    key_col = 'gt_title'
+                else:
+                    key_col = key_val = None
+                if key_col and key_val is not None:
+                    row = df_for_gt[df_for_gt[key_col] == key_val]
+                    if not row.empty:
+                        excluded_gt_fields = {'gt_procedure', 'gt_dspy_uuid', 'gt_dspy_split'}
+                        gt_columns = [col for col in df_for_gt.columns if col.startswith('gt_') and col not in key_columns and col not in excluded_gt_fields]
+                        if gt_columns:
+                            for gt_col in gt_columns:
+                                value = row[gt_col].iloc[0]
+                                try:
+                                    # Try to parse as JSON first
+                                    parsed_json = json.loads(str(value))
+                                    formatted_json = json.dumps(parsed_json, indent=2)
+                                    st.markdown(f"**{gt_col}:**")
+                                    st.code(formatted_json, language='json')
+                                except json.JSONDecodeError:
+                                    try:
+                                        # If JSON fails, try to evaluate as Python literal (handles single quotes)
+                                        import ast
+                                        parsed_json = ast.literal_eval(str(value))
+                                        formatted_json = json.dumps(parsed_json, indent=2)
+                                        st.markdown(f"**{gt_col}:**")
+                                        st.code(formatted_json, language='json')
+                                    except (ValueError, SyntaxError):
+                                        # If all parsing fails, show as raw text
+                                        st.markdown(f"**{gt_col}:** {value}")
+                                except (TypeError, ValueError):
+                                    st.markdown(f"**{gt_col}:** {value}")
+                        else:
+                            st.info("No additional ground truth metrics available.")
+                    else:
+                        st.info("No ground truth metrics for this record.")
+            else:
+                st.info("No evaluation table available for ground truth metrics.")
+
+    st.subheader("Your Preference")
+    pref = st.radio(
+        "Which response do you prefer?", options=[m1, m2], key=f"pref_{idx}"
+    )
+
+    st.subheader("Comments (Optional)")
+    comments = st.text_area(
+        "Add any comments or notes about your preference:",
+        height=100,
+        key=f"comments_{idx}",
+        placeholder="Optional: Explain your reasoning or add any observations..."
+    )
+
+    if st.session_state.feedback_saved:
+        st.success("Feedback saved.")
+        st.session_state.feedback_saved = False
+
+    feedback_path = os.path.join(SCRIPT_DIR, 'feedback.tsv')
+    header = [
+        'timestamp', 'record_id', 'model_1', 'model_2', 'preference',
+        'input', 'response_1', 'response_2', 'ground_truth', 'comments'
+    ]
+    row = [
+        datetime.now().isoformat(), current_id, m1, m2, pref,
+        rec1.get('input', ''), rec1.get('predicted', ''), rec2.get('predicted', ''),
+        rec1.get('ground_truth', ''), comments
+    ]
+    def submit_feedback():
+        append_feedback(feedback_path, header, row)
+        st.session_state.idx += 1
+        st.session_state.feedback_saved = True
+
+    st.button("Submit and Next", on_click=submit_feedback)
+
+
+if __name__ == '__main__':
+    main()
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
-import altair as alt
-import numpy as np
-import pandas as pd
-import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
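
For context, a minimal sketch of the data layout that load_models and load_eval_tables in the new src/app.py expect: one subdirectory per model under src/for_experiments_prediction, each holding a predicted_vs_gt.json whose sections flatten into records keyed by 'id', plus an optional evaluation_table.parquet. The file names and record keys ('id', 'input', 'predicted', 'ground_truth') come from the code above; the model directory name and values below are purely illustrative.

import json
import os

# Hypothetical model folder; only 'for_experiments_prediction' and the
# file name 'predicted_vs_gt.json' are fixed by app.py.
model_dir = os.path.join("src", "for_experiments_prediction", "example-model")
os.makedirs(model_dir, exist_ok=True)

example = {
    "some_section": [  # list-valued sections are flattened directly by load_models
        {
            "id": "rec-001",  # becomes the record key
            "input": "Prompt text shown to the model",
            "predicted": "Model output text",
            "ground_truth": "Reference text",
        }
    ]
}

with open(os.path.join(model_dir, "predicted_vs_gt.json"), "w", encoding="utf-8") as f:
    json.dump(example, f, indent=2)

Records that share an 'id' across two such model directories become the pairs the app shows side by side; if an evaluation_table.parquet sits next to the JSON, its gt_sac_id or gt_title column is used to look up per-record metrics.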