| """ | |
| Streamlit app for human evaluation of model outputs. | |
| Allows users to select two models, compare their responses to the same inputs, | |
| and record preferences for subsequent analysis. | |
| """ | |
| import os | |
| import json | |
| import csv | |
| from datetime import datetime | |
| import streamlit as st | |
| import pandas as pd | |
| import tempfile | |
| st.set_page_config(page_title="Model Comparison Evaluation", layout="wide") | |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| DATA_DIR = os.path.join(SCRIPT_DIR, "for_experiments_prediction") | |
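
# Expected layout under DATA_DIR (a sketch inferred from the loaders below; the
# model directory names are discovered at runtime via os.walk and may be nested):
#
#   for_experiments_prediction/
#       <model_name>/
#           predicted_vs_gt.json        # per-record predictions vs. ground truth
#           evaluation_table.parquet    # per-record metric scores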


def load_models(data_dir):
    """
    Discover prediction JSON files named 'predicted_vs_gt.json' and load flattened records for each model.
    Returns a dict mapping model name to dict of {id: record}.
    """
    model_paths = {}
    for root, _, files in os.walk(data_dir):
        for fname in files:
            if fname == 'predicted_vs_gt.json':
                path = os.path.join(root, fname)
                rel = os.path.relpath(root, data_dir)
                model_paths[rel] = path
    models = {}
    for model_name, path in sorted(model_paths.items()):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        records = {}
        for section in data.values():
            if isinstance(section, dict):
                for sub in section.values():
                    for rec in sub:
                        records[rec['id']] = rec
            elif isinstance(section, list):
                for rec in section:
                    records[rec['id']] = rec
        models[model_name] = records
    return models
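
# Illustrative shape of a flattened record consumed by the UI below (field names
# taken from the .get() calls in main(); concrete values are hypothetical):
#
#   {
#       "id": "rec-001",
#       "input": "<prompt text>",
#       "predicted": "<model response>",
#       "ground_truth": "<reference answer>",
#       "gt_sac_id": "...",  # optional key used to join with evaluation_table.parquet
#   }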


def append_feedback(feedback_file, header, row):
    """
    Append a single feedback row to TSV, writing header if file does not exist.
    """
    write_header = not os.path.exists(feedback_file)
    with open(feedback_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter='\t', quoting=csv.QUOTE_ALL)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)
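
# A sketch of the resulting feedback.tsv (tab-separated, every field double-quoted
# because of QUOTE_ALL; values below are illustrative only, middle columns elided):
#
#   "timestamp"            "record_id"  "model_1"  "model_2"  "preference"  ...  "comments"
#   "2024-01-01T12:00:00"  "rec-001"    "model_a"  "model_b"  "model_a"     ...  "cleaner output"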


def load_eval_tables(data_dir):
    """
    Discover evaluation_table.parquet files under each model directory and load each into a pandas DataFrame.
    Returns a dict mapping model name to its evaluation DataFrame.
    """
    tables = {}
    for root, _, files in os.walk(data_dir):
        if 'evaluation_table.parquet' in files:
            path = os.path.join(root, 'evaluation_table.parquet')
            rel = os.path.relpath(root, data_dir)
            tables[rel] = pd.read_parquet(path)
    return tables
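
# Expected columns in evaluation_table.parquet (inferred from how main() uses the
# tables; only the join key is strictly required, everything else is optional):
#   - 'gt_sac_id' or 'gt_title': key used to match a record to its metric row
#   - the per-record scores listed in `fixed_metrics` below
#   - any other non-'gt_' column, offered via the "additional metric" selector
#   - extra 'gt_*' columns, shown in the "Ground Truth Metrics" panel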


def main():
    st.title("Model Comparison Evaluation")
    models = load_models(DATA_DIR)
    eval_tables = load_eval_tables(DATA_DIR)

    # Collect all metric columns available across the evaluation tables.
    all_cols = set()
    for df in eval_tables.values():
        all_cols.update(df.columns)
    key_columns = {'gt_sac_id', 'gt_title'}
    metric_columns = sorted(all_cols - key_columns)
    fixed_metrics = [
        'chemicals_accuracy',
        'chemicals_f1_score',
        'chemicals_precision',
        'chemicals_recall',
        'metal_accuracy',
        'metal_f1_score',
        'metal_precision',
        'metal_recall',
        'procedure_procedure_completeness_score',
        'procedure_procedure_order_score',
        'procedure_procedure_accuracy_score',
        'support_accuracy',
        'support_f1_score',
        'support_precision',
        'support_recall',
    ]
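
    # Columns not in `fixed_metrics` (and not 'gt_'-prefixed) are not shown in the
    # per-record tables; they are offered one at a time via the
    # "Select additional metric to display" dropdown further down.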
    other_metrics = sorted([c for c in metric_columns if c not in fixed_metrics and not c.startswith('gt_')])

    model_names = list(models.keys())
    st.sidebar.header("Configuration")

    def reset_index():
        st.session_state.idx = 0
        # Reset feedback file when models change
        feedback_path = os.path.join(tempfile.gettempdir(), 'feedback.tsv')
        if os.path.exists(feedback_path):
            os.remove(feedback_path)

    selected = st.sidebar.multiselect(
        "Select exactly two models to compare",
        options=model_names,
        key='models',
        help="Choose two model variants for side-by-side comparison",
        on_change=reset_index
    )
    if len(selected) != 2:
        st.sidebar.info("Please select exactly two models.")
        st.stop()
    # Download button for feedback TSV
    feedback_path = os.path.join(tempfile.gettempdir(), 'feedback.tsv')
    if os.path.exists(feedback_path):
        with open(feedback_path, 'r', encoding='utf-8') as f:
            tsv_data = f.read()
        st.sidebar.download_button(
            label="Download Feedback TSV",
            data=tsv_data,
            file_name="feedback.tsv",
            mime="text/tab-separated-values"
        )

    m1, m2 = selected
    recs1 = models[m1]
    recs2 = models[m2]
    common_ids = sorted(set(recs1.keys()) & set(recs2.keys()))
    if not common_ids:
        st.error("No common records between the selected models.")
        st.stop()
    if 'idx' not in st.session_state:
        st.session_state.idx = 0
    if 'feedback_saved' not in st.session_state:
        st.session_state.feedback_saved = False
    # Initialize fresh feedback file for new session
    if 'session_initialized' not in st.session_state:
        feedback_path = os.path.join(tempfile.gettempdir(), 'feedback.tsv')
        if os.path.exists(feedback_path):
            os.remove(feedback_path)
        st.session_state.session_initialized = True

    total = len(common_ids)
    idx = st.session_state.idx
    if idx < 0:
        idx = 0
    if idx >= total:
        st.write("### Evaluation complete! Thank you for your feedback.")
        st.stop()
    current_id = common_ids[idx]
    rec1 = recs1[current_id]
    rec2 = recs2[current_id]

    st.markdown(f"**Record {idx+1}/{total} — ID: {current_id}**")
    st.markdown("---")
    st.subheader("Input Prompt")
    st.code(rec1.get('input', ''), language='')

    st.subheader("Model Responses and Ground Truth")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown(f"**{m1}**")
        st.text_area("", rec1.get('predicted', ''), height=600, key=f"resp1_{idx}")
    with col2:
        st.markdown(f"**{m2}**")
        st.text_area("", rec2.get('predicted', ''), height=600, key=f"resp2_{idx}")
    with col3:
        st.markdown("**Ground Truth**")
        st.text_area("", rec1.get('ground_truth', ''), height=600, key=f"gt_{idx}")
    fcol1, fcol2, fcol3 = st.columns(3)
    with fcol1:
        df1 = eval_tables.get(m1)
        if df1 is not None:
            if 'gt_sac_id' in df1.columns:
                key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
                key_col = 'gt_sac_id'
            elif 'gt_title' in df1.columns:
                key_val = rec1.get('gt_title', rec1.get('title'))
                key_col = 'gt_title'
            else:
                key_col = key_val = None
            if key_col and key_val is not None:
                row = df1[df1[key_col] == key_val]
                if not row.empty:
                    fm_df = row[fixed_metrics].T
                    fm_df.columns = ['value']
                    st.table(fm_df)
                else:
                    st.info("No fixed metrics for this record.")
        else:
            st.info("No evaluation table available for this model.")
    with fcol2:
        df2 = eval_tables.get(m2)
        if df2 is not None:
            if 'gt_sac_id' in df2.columns:
                key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
                key_col = 'gt_sac_id'
            elif 'gt_title' in df2.columns:
                key_val = rec1.get('gt_title', rec1.get('title'))
                key_col = 'gt_title'
            else:
                key_col = key_val = None
            if key_col and key_val is not None:
                row = df2[df2[key_col] == key_val]
                if not row.empty:
                    # Mirror the first column: show the fixed metrics as a one-column table.
                    fm_df = row[fixed_metrics].T
                    fm_df.columns = ['value']
                    st.table(fm_df)
                else:
                    st.info("No fixed metrics for this record.")
        else:
            st.info("No evaluation table available for this model.")
    if other_metrics:
        selected_metric = st.selectbox(
            "Select additional metric to display",
            options=other_metrics,
            key=f"metric_sel_{idx}"
        )
    else:
        selected_metric = None
    if selected_metric:
        mcol1, mcol2, mcol3 = st.columns(3)
        with mcol1:
            df1 = eval_tables.get(m1)
            if df1 is not None and selected_metric in df1.columns:
                if 'gt_sac_id' in df1.columns:
                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
                    key_col = 'gt_sac_id'
                elif 'gt_title' in df1.columns:
                    key_val = rec1.get('gt_title', rec1.get('title'))
                    key_col = 'gt_title'
                else:
                    key_col = key_val = None
                if key_col and key_val is not None:
                    row = df1[df1[key_col] == key_val]
                    if not row.empty:
                        value = row[selected_metric].iloc[0]
                        try:
                            # Try to parse as JSON first
                            parsed_json = json.loads(str(value))
                            formatted_json = json.dumps(parsed_json, indent=2)
                            st.markdown(f"**{selected_metric}:**")
                            st.code(formatted_json, language='json')
                        except json.JSONDecodeError:
                            try:
                                # If JSON fails, try to evaluate as Python literal (handles single quotes)
                                parsed_json = ast.literal_eval(str(value))
                                formatted_json = json.dumps(parsed_json, indent=2)
                                st.markdown(f"**{selected_metric}:**")
                                st.code(formatted_json, language='json')
                            except (ValueError, SyntaxError):
                                # If all parsing fails, show as raw text
                                st.markdown(f"**{selected_metric}:** {value}")
                        except (TypeError, ValueError):
                            st.markdown(f"**{selected_metric}:** {value}")
                    else:
                        st.markdown(f"**{selected_metric}:** N/A")
            else:
                st.markdown(f"**{selected_metric}:** N/A")
        with mcol2:
            df2 = eval_tables.get(m2)
            if df2 is not None and selected_metric in df2.columns:
                if 'gt_sac_id' in df2.columns:
                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
                    key_col = 'gt_sac_id'
                elif 'gt_title' in df2.columns:
                    key_val = rec1.get('gt_title', rec1.get('title'))
                    key_col = 'gt_title'
                else:
                    key_col = key_val = None
                if key_col and key_val is not None:
                    row = df2[df2[key_col] == key_val]
                    if not row.empty:
                        value = row[selected_metric].iloc[0]
                        try:
                            # Try to parse as JSON first
                            parsed_json = json.loads(str(value))
                            formatted_json = json.dumps(parsed_json, indent=2)
                            st.markdown(f"**{selected_metric}:**")
                            st.code(formatted_json, language='json')
                        except json.JSONDecodeError:
                            try:
                                # If JSON fails, try to evaluate as Python literal (handles single quotes)
                                parsed_json = ast.literal_eval(str(value))
                                formatted_json = json.dumps(parsed_json, indent=2)
                                st.markdown(f"**{selected_metric}:**")
                                st.code(formatted_json, language='json')
                            except (ValueError, SyntaxError):
                                # If all parsing fails, show as raw text
                                st.markdown(f"**{selected_metric}:** {value}")
                        except (TypeError, ValueError):
                            st.markdown(f"**{selected_metric}:** {value}")
                    else:
                        st.markdown(f"**{selected_metric}:** N/A")
            else:
                st.markdown(f"**{selected_metric}:** N/A")
        with mcol3:
            st.markdown("**Ground Truth Metrics**")
            df_for_gt = eval_tables.get(m1)
            if df_for_gt is None:
                df_for_gt = eval_tables.get(m2)
            if df_for_gt is not None:
                if 'gt_sac_id' in df_for_gt.columns:
                    key_val = rec1.get('gt_sac_id', rec1.get('sac_id'))
                    key_col = 'gt_sac_id'
                elif 'gt_title' in df_for_gt.columns:
                    key_val = rec1.get('gt_title', rec1.get('title'))
                    key_col = 'gt_title'
                else:
                    key_col = key_val = None
                if key_col and key_val is not None:
                    row = df_for_gt[df_for_gt[key_col] == key_val]
                    if not row.empty:
                        excluded_gt_fields = {'gt_procedure', 'gt_dspy_uuid', 'gt_dspy_split'}
                        gt_columns = [col for col in df_for_gt.columns
                                      if col.startswith('gt_') and col not in key_columns and col not in excluded_gt_fields]
                        if gt_columns:
                            for gt_col in gt_columns:
                                value = row[gt_col].iloc[0]
                                try:
                                    # Try to parse as JSON first
                                    parsed_json = json.loads(str(value))
                                    formatted_json = json.dumps(parsed_json, indent=2)
                                    st.markdown(f"**{gt_col}:**")
                                    st.code(formatted_json, language='json')
                                except json.JSONDecodeError:
                                    try:
                                        # If JSON fails, try to evaluate as Python literal (handles single quotes)
                                        parsed_json = ast.literal_eval(str(value))
                                        formatted_json = json.dumps(parsed_json, indent=2)
                                        st.markdown(f"**{gt_col}:**")
                                        st.code(formatted_json, language='json')
                                    except (ValueError, SyntaxError):
                                        # If all parsing fails, show as raw text
                                        st.markdown(f"**{gt_col}:** {value}")
                                except (TypeError, ValueError):
                                    st.markdown(f"**{gt_col}:** {value}")
                        else:
                            st.info("No additional ground truth metrics available.")
                    else:
                        st.info("No ground truth metrics for this record.")
            else:
                st.info("No evaluation table available for ground truth metrics.")
| st.subheader("Your Preference") | |
| pref = st.radio( | |
| "Which response do you prefer?", options=[m1, m2], key=f"pref_{idx}" | |
| ) | |
| st.subheader("Comments (Optional)") | |
| comments = st.text_area( | |
| "Add any comments or notes about your preference:", | |
| height=100, | |
| key=f"comments_{idx}", | |
| placeholder="Optional: Explain your reasoning or add any observations..." | |
| ) | |
| if st.session_state.feedback_saved: | |
| st.success("Feedback saved.") | |
| st.session_state.feedback_saved = False | |
| header = [ | |
| 'timestamp', 'record_id', 'model_1', 'model_2', 'preference', | |
| 'input', 'response_1', 'response_2', 'ground_truth', 'comments' | |
| ] | |
| row = [ | |
| datetime.now().isoformat(), current_id, m1, m2, pref, | |
| rec1.get('input', ''), rec1.get('predicted', ''), rec2.get('predicted', ''), | |
| rec1.get('ground_truth', ''), comments | |
| ] | |
    def submit_feedback():
        # Get the current text box content at the time of submission
        current_comments = st.session_state.get(f"comments_{idx}", "")
        # Update the row with the current comments
        row[9] = current_comments  # comments is at index 9
        append_feedback(feedback_path, header, row)
        st.session_state.idx += 1
        st.session_state.feedback_saved = True

    st.button("Submit and Next", on_click=submit_feedback)


if __name__ == '__main__':
    main()
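
# To run locally (a sketch; the filename app.py is illustrative, and the data
# directory described near the top must sit next to this file):
#
#   pip install streamlit pandas pyarrow   # pyarrow (or fastparquet) is needed for pd.read_parquet
#   streamlit run app.py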