# LeMat-GenBench / app.py
from pathlib import Path
import json
import pandas as pd
import numpy as np
import gradio as gr
from datasets import load_dataset
from gradio_leaderboard import Leaderboard
from datetime import datetime
import os
from about import (
PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
METRIC_GROUP_COLORS, COLUMN_TO_GROUP
)
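# Note: the constants imported from about.py (repository IDs, column display names,
# metric groupings, group colors, and the compact-view column list) drive most of the
# formatting logic below; see about.py for their definitions.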
def get_leaderboard():
ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
full_df = pd.DataFrame(ds)
print(full_df.columns)
if len(full_df) == 0:
return pd.DataFrame({'date':[], 'model':[], 'score':[], 'verified':[]})
return full_df
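# Illustrative usage (a sketch; assumes the results dataset holds one row per evaluated
# model, with a 'model_name' column, an 'n_structures' count, and numeric metric columns):
#   df = get_leaderboard()
#   print(df[['model_name', 'n_structures']].head())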
def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
"""Format the dataframe with proper column names and optional percentages."""
if len(df) == 0:
return df
# Build column list based on view mode
selected_cols = ['model_name']
if compact_view:
# Use predefined compact columns
from about import COMPACT_VIEW_COLUMNS
selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
else:
# Build from selected groups
if 'n_structures' in df.columns:
selected_cols.append('n_structures')
# If no groups selected, show all
if not selected_groups:
selected_groups = list(METRIC_GROUPS.keys())
# Add columns from selected groups
for group in selected_groups:
if group in METRIC_GROUPS:
for col in METRIC_GROUPS[group]:
if col in df.columns and col not in selected_cols:
selected_cols.append(col)
# Create a copy with selected columns
display_df = df[selected_cols].copy()
# Add symbols to model names based on various properties
if 'model_name' in display_df.columns:
def add_model_symbols(row):
name = row['model_name']
symbols = []
# Add relaxed symbol
if 'relaxed' in df.columns and row.get('relaxed', False):
symbols.append('⚑')
# Add reference dataset symbols
# β˜… for Alexandria and OQMD (in-distribution, part of reference dataset)
if name in ['Alexandria', 'OQMD']:
symbols.append('β˜…')
# β—† for AFLOW (out-of-distribution relative to reference dataset)
elif name == 'AFLOW':
symbols.append('β—†')
return f"{name} {' '.join(symbols)}" if symbols else name
display_df['model_name'] = df.apply(add_model_symbols, axis=1)
# Convert count-based metrics to percentages if requested
if show_percentage and 'n_structures' in df.columns:
n_structures = df['n_structures']
for col in COUNT_BASED_METRICS:
if col in display_df.columns:
# Calculate percentage and format as string with %
display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'
# Round numeric columns for cleaner display
for col in display_df.columns:
if display_df[col].dtype in ['float64', 'float32']:
display_df[col] = display_df[col].round(4)
# Rename columns for display
display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
# Apply color coding based on metric groups
styled_df = apply_color_styling(display_df, selected_cols)
return styled_df
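# Illustrative usage (a sketch; assumes the COUNT_BASED_METRICS columns hold raw counts
# out of n_structures, so e.g. 950 valid structures out of 1000 renders as "95.0%"):
#   styled = format_dataframe(df, show_percentage=True, compact_view=True)
#   # When rows exist, `styled` is a pandas Styler and can be passed straight to gr.Dataframe.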
def apply_color_styling(display_df, original_cols):
"""Apply background colors to dataframe based on metric groups using pandas Styler."""
def style_by_group(x):
# Create a DataFrame with the same shape filled with empty strings
styles = pd.DataFrame('', index=x.index, columns=x.columns)
# Map display column names back to original column names
for i, display_col in enumerate(x.columns):
if i < len(original_cols):
original_col = original_cols[i]
# Check if this column belongs to a metric group
if original_col in COLUMN_TO_GROUP:
group = COLUMN_TO_GROUP[original_col]
color = METRIC_GROUP_COLORS.get(group, '')
if color:
styles[display_col] = f'background-color: {color}'
return styles
# Apply the styling function
return display_df.style.apply(style_by_group, axis=None)
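# Design note: style_by_group maps displayed columns back to their raw names by position,
# which relies on DataFrame.rename preserving column order. gr.Dataframe can render the
# returned pandas Styler, so the group background colors appear in the leaderboard table.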
def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
"""Update the leaderboard based on user selections.
Uses cached dataframe to avoid re-downloading data on every change.
"""
# Use cached dataframe instead of re-downloading
df_to_format = cached_df.copy()
# Convert display name back to raw column name for sorting
if sort_by and sort_by != "None":
# Create reverse mapping from display names to raw column names
display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
raw_column_name = display_to_raw.get(sort_by, sort_by)
if raw_column_name in df_to_format.columns:
ascending = (sort_direction == "Ascending")
df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending)
formatted_df = format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)
return formatted_df
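# Illustrative call (a sketch; "Novel" stands in for whatever display name
# COLUMN_DISPLAY_NAMES assigns to a novelty column):
#   update_leaderboard(show_percentage=True, selected_groups=None, compact_view=False,
#                      cached_df=initial_df, sort_by="Novel", sort_direction="Descending")
# Sorting happens on the raw numeric column before formatting, so percentage strings
# never affect the sort order.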
def show_output_box(message):
return gr.update(value=message, visible=True)
def submit_cif_files(model_name, problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
"""Submit structures to the leaderboard."""
from huggingface_hub import upload_file
# Validate inputs
if not model_name or not model_name.strip():
return "Error: Please provide a model name.", None
if not problem_type:
return "Error: Please select a problem type.", None
if not cif_files:
return "Error: Please upload a file.", None
if not profile:
return "Error: Please log in to submit.", None
try:
username = profile.username
timestamp = datetime.now().isoformat()
# Create submission metadata
submission_data = {
"username": username,
"model_name": model_name.strip(),
"problem_type": problem_type,
"relaxed": relaxed,
"timestamp": timestamp,
"file_name": Path(cif_files).name
}
# Create a unique submission ID
submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"
# Upload the submission file
file_path = Path(cif_files)
uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"
upload_file(
path_or_fileobj=str(file_path),
path_in_repo=uploaded_file_path,
repo_id=submissions_repo,
token=TOKEN,
repo_type="dataset"
)
# Upload metadata as JSON
metadata_path = f"submissions/{submission_id}/metadata.json"
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(submission_data, f, indent=2)
temp_metadata_path = f.name
upload_file(
path_or_fileobj=temp_metadata_path,
path_in_repo=metadata_path,
repo_id=submissions_repo,
token=TOKEN,
repo_type="dataset"
)
# Clean up temp file
os.unlink(temp_metadata_path)
return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
except Exception as e:
return f"Error during submission: {str(e)}", None
def generate_metric_legend_html():
"""Generate HTML table with color-coded metric group legend."""
metric_details = {
'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
}
html = '<table style="width: 100%; border-collapse: collapse;">'
html += '<thead><tr>'
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Color</th>'
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Group</th>'
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Metrics</th>'
html += '<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">Direction</th>'
html += '</tr></thead><tbody>'
for group, color in METRIC_GROUP_COLORS.items():
metrics, direction = metric_details.get(group, ('', ''))
group_name = group.replace('↑', '').replace('↓', '').strip()
html += '<tr>'
html += f'<td style="border: 1px solid #ddd; padding: 8px;"><div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div></td>'
html += f'<td style="border: 1px solid #ddd; padding: 8px;"><strong>{group_name}</strong></td>'
html += f'<td style="border: 1px solid #ddd; padding: 8px;">{metrics}</td>'
html += f'<td style="border: 1px solid #ddd; padding: 8px;">{direction}</td>'
html += '</tr>'
html += '</tbody></table>'
return html
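# Note: the legend iterates over METRIC_GROUP_COLORS, so the keys of metric_details above
# must match the group names defined in about.py exactly (including the ↑/↓ arrows) for the
# Metrics and Direction cells to be populated.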
def gradio_interface() -> gr.Blocks:
with gr.Blocks() as demo:
gr.Markdown("""
# πŸ”¬ LeMat-GenBench: A Unified Benchmark for Generative Models of Crystalline Materials
Generative machine learning models hold great promise for accelerating materials discovery, particularly through the inverse design of inorganic crystals, enabling an unprecedented exploration of chemical space. Yet, the lack of standardized evaluation frameworks makes it difficult to meaningfully evaluate, compare, and further develop these models.
**LeMat-GenBench** introduces a unified benchmark for generative models of crystalline materials, with **standardized evaluation metrics** for meaningful model comparison, a diverse set of evaluation tasks, and this leaderboard to encourage and track community progress.
πŸ“„ **Paper**: [arXiv preprint](https://arxiv.org/abs/XXXX.XXXXX) | πŸ’» **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | πŸ“§ **Contact**: siddharth.betala-ext [at] entalpic.ai, alexandre.duval [at] entalpic.ai
""")
with gr.Tabs(elem_classes="tab-buttons"):
with gr.TabItem("πŸš€ Leaderboard", elem_id="boundary-benchmark-tab-table"):
gr.Markdown("# LeMat-GenBench")
# Display options
with gr.Row():
with gr.Column(scale=1):
compact_view = gr.Checkbox(
value=True,
label="Compact View",
info="Show only key metrics"
)
show_percentage = gr.Checkbox(
value=True,
label="Show as Percentages",
info="Display count-based metrics as percentages of total structures"
)
with gr.Column(scale=1):
# Sort choices are display names; update_leaderboard maps them back to raw column names
sort_choices = ["None"] + list(COLUMN_DISPLAY_NAMES.values())
sort_by = gr.Dropdown(
choices=sort_choices,
value="None",
label="Sort By",
info="Select column to sort by"
)
sort_direction = gr.Radio(
choices=["Ascending", "Descending"],
value="Descending",
label="Sort Direction"
)
with gr.Column(scale=2):
selected_groups = gr.CheckboxGroup(
choices=list(METRIC_GROUPS.keys()),
value=list(METRIC_GROUPS.keys()),
label="Metric Families (only active when Compact View is off)",
info="Select which metric groups to display"
)
# Metric legend with color coding
with gr.Accordion("Metric Groups Legend", open=False):
gr.HTML(generate_metric_legend_html())
try:
# Initial dataframe - load once and cache
initial_df = get_leaderboard()
cached_df_state = gr.State(initial_df)
formatted_df = format_dataframe(initial_df, show_percentage=True, selected_groups=list(METRIC_GROUPS.keys()), compact_view=True)
leaderboard_table = gr.Dataframe(
label="GenBench Leaderboard",
value=formatted_df,
interactive=False,
wrap=True,
column_widths=["180px"] + ["160px"] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
show_fullscreen_button=True
)
# Update dataframe when options change (using cached data)
show_percentage.change(
fn=update_leaderboard,
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
outputs=leaderboard_table
)
selected_groups.change(
fn=update_leaderboard,
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
outputs=leaderboard_table
)
compact_view.change(
fn=update_leaderboard,
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
outputs=leaderboard_table
)
sort_by.change(
fn=update_leaderboard,
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
outputs=leaderboard_table
)
sort_direction.change(
fn=update_leaderboard,
inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
outputs=leaderboard_table
)
except Exception as e:
gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")
gr.Markdown("""
**Symbol Legend:**
- ⚑ Structures were already relaxed
- β˜… Contributes to LeMat-Bulk reference dataset (in-distribution)
- β—† Out-of-distribution relative to LeMat-Bulk reference dataset
A verified submission is one whose results came from a model submission rather than a CIF submission.
""")
with gr.TabItem("βœ‰οΈ Submit", elem_id="boundary-benchmark-tab-table"):
gr.Markdown(
"""
# Materials Submission
Upload a CSV file, a pickle (.pkl) file, or a ZIP archive of CIF files containing your structures.
"""
)
filename = gr.State(value=None)
gr.LoginButton()
with gr.Row():
with gr.Column():
model_name_input = gr.Textbox(
label="Model Name",
placeholder="Enter your model name",
info="Provide a name for your model/method"
)
problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
with gr.Column():
cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
relaxed = gr.Checkbox(
value=False,
label="Structures are already relaxed",
info="Check this box if your submitted structures have already been relaxed"
)
submit_btn = gr.Button("Submit")
message = gr.Textbox(label="Status", lines=1, visible=False)
# help message
gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")
submit_btn.click(
submit_cif_files,
inputs=[model_name_input, problem_type, cif_file, relaxed],
outputs=[message, filename],
).then(
fn=show_output_box,
inputs=[message],
outputs=[message],
)
return demo
if __name__ == "__main__":
gradio_interface().launch()