File size: 17,909 Bytes
1070692
 
 
f83fa62
1070692
 
 
 
 
 
 
f83fa62
 
ed922f9
 
f83fa62
1070692
 
 
 
f83fa62
127c41d
 
1070692
127c41d
1070692
205c7b6
f83fa62
 
 
 
205c7b6
99351a1
ed922f9
205c7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fa62
 
 
 
4283cea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99351a1
f83fa62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed922f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fa62
ed922f9
 
 
 
 
 
 
 
 
 
 
 
205c7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed922f9
f83fa62
1070692
 
 
b9e1a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127c41d
134e659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed922f9
134e659
 
 
 
 
 
 
 
 
 
 
ed922f9
 
1070692
 
4283cea
 
 
 
 
 
 
 
 
1070692
f83fa62
 
 
 
 
 
205c7b6
 
 
 
 
f83fa62
134e659
f83fa62
 
 
205c7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fa62
134e659
f83fa62
134e659
f83fa62
 
205c7b6
f83fa62
205c7b6
 
 
f83fa62
 
 
 
 
3bfb60a
 
 
1070692
f83fa62
205c7b6
f83fa62
 
205c7b6
f83fa62
 
ed922f9
f83fa62
205c7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83fa62
 
 
 
 
1070692
4283cea
 
 
 
 
1070692
4283cea
 
1070692
 
 
 
03f4e0c
 
1070692
 
 
 
127c41d
1070692
 
 
b9e1a9b
 
 
 
 
1070692
 
7e8b6aa
b9e1a9b
 
 
 
 
1070692
 
 
 
 
 
 
127c41d
b9e1a9b
1070692
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
from pathlib import Path
import json
import pandas as pd
import numpy as np

import gradio as gr
from datasets import load_dataset
from gradio_leaderboard import Leaderboard
from datetime import datetime
import os

from about import (
    PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo,
    COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS,
    METRIC_GROUP_COLORS, COLUMN_TO_GROUP
)

def get_leaderboard():
    """Download the published results dataset and return it as a DataFrame.

    Forces a fresh download on every call so the leaderboard always reflects
    the latest evaluation results pushed to ``results_repo``.

    Returns:
        pd.DataFrame: One row per evaluated model. When the repo has no
        results yet, an empty placeholder frame with a minimal schema is
        returned so downstream rendering still has columns to bind to.
    """
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    # Removed leftover debug print of the column index; it spammed the Space
    # logs on every page render.
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})

    return full_df

def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Prepare a raw results frame for display in the leaderboard table.

    Pipeline (order matters — see inline notes):
      1. choose columns (compact preset, or the selected metric groups),
      2. decorate model names with status symbols,
      3. optionally turn count metrics into percentage strings,
      4. round remaining float columns,
      5. rename columns to display names,
      6. wrap in a Styler that colors columns by metric group.

    Args:
        df: Raw results dataframe, one row per model.
        show_percentage: If True, count-based metrics are rendered as
            percentages of ``n_structures``.
        selected_groups: Metric-group names to include when not in compact
            view; falsy means "all groups".
        compact_view: If True, use the predefined compact column set and
            ignore ``selected_groups``.

    Returns:
        A pandas Styler, or ``df`` unchanged when it is empty.
    """
    if len(df) == 0:
        return df

    # Build column list based on view mode.
    selected_cols = ['model_name']

    if compact_view:
        # Use the predefined compact column set. NOTE(review): this replaces
        # selected_cols entirely — assumes COMPACT_VIEW_COLUMNS contains
        # 'model_name'; confirm in about.py.
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
    else:
        # Build the column list from the selected metric groups.
        if 'n_structures' in df.columns:
            selected_cols.append('n_structures')

        # If no groups were selected, fall back to showing all of them.
        if not selected_groups:
            selected_groups = list(METRIC_GROUPS.keys())

        # Append each group's columns, keeping first-seen order, no duplicates.
        for group in selected_groups:
            if group in METRIC_GROUPS:
                for col in METRIC_GROUPS[group]:
                    if col in df.columns and col not in selected_cols:
                        selected_cols.append(col)

    # Work on a copy restricted to the selected columns; df stays untouched.
    display_df = df[selected_cols].copy()

    # Decorate model names with status symbols (see the symbol legend in the UI).
    if 'model_name' in display_df.columns:
        def add_model_symbols(row):
            """Return the model name suffixed with its status symbols."""
            name = row['model_name']
            symbols = []

            # ⚑ marks submissions whose structures were already relaxed.
            if 'relaxed' in df.columns and row.get('relaxed', False):
                symbols.append('⚑')

            # Reference-dataset markers:
            # β˜… for Alexandria and OQMD (in-distribution, part of reference dataset)
            if name in ['Alexandria', 'OQMD']:
                symbols.append('β˜…')
            # β—† for AFLOW (out-of-distribution relative to reference dataset)
            elif name == 'AFLOW':
                symbols.append('β—†')

            return f"{name} {' '.join(symbols)}" if symbols else name

        # Apply over the ORIGINAL df so 'relaxed' is visible even when it was
        # not part of selected_cols; indices match, so assignment aligns.
        display_df['model_name'] = df.apply(add_model_symbols, axis=1)

    # Convert count-based metrics to percentage strings if requested. Doing
    # this BEFORE the rounding pass means these columns become dtype object
    # and are deliberately skipped by the float rounding below.
    if show_percentage and 'n_structures' in df.columns:
        n_structures = df['n_structures']
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                # Calculate percentage and format as a string with '%'.
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'

    # Round remaining float columns for a cleaner display.
    for col in display_df.columns:
        if display_df[col].dtype in ['float64', 'float32']:
            display_df[col] = display_df[col].round(4)

    # Rename to human-readable display names. Column ORDER is preserved, which
    # apply_color_styling relies on to map display columns back to raw names.
    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)

    # Apply metric-group background colors via a positional column mapping.
    styled_df = apply_color_styling(display_df, selected_cols)

    return styled_df

def apply_color_styling(display_df, original_cols):
    """Wrap *display_df* in a Styler that colors columns by metric group.

    ``original_cols`` carries the raw (pre-rename) column names, positionally
    aligned with ``display_df``'s columns, so each display column can be
    looked up in COLUMN_TO_GROUP to find its group color.
    """

    def _group_css(frame):
        # Blank style grid matching the data's shape; cells left as '' keep
        # their default appearance.
        css = pd.DataFrame('', index=frame.index, columns=frame.columns)

        # Walk display and raw names in lockstep; zip stops at the shorter
        # sequence, mirroring the original bounds check.
        for shown_col, raw_col in zip(frame.columns, original_cols):
            group = COLUMN_TO_GROUP.get(raw_col)
            if group is None:
                continue
            color = METRIC_GROUP_COLORS.get(group, '')
            if color:
                css[shown_col] = f'background-color: {color}'

        return css

    # axis=None hands the whole DataFrame to the styling function at once.
    return display_df.style.apply(_group_css, axis=None)

def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
    """Re-render the leaderboard after a display-option change.

    Operates on the dataframe cached in gr.State, so UI interactions never
    trigger a dataset re-download.
    """
    working_df = cached_df.copy()

    # "None" is the dropdown's explicit no-sort sentinel.
    if sort_by and sort_by != "None":
        # The dropdown lists display names; translate back to the raw column,
        # falling through unchanged when no mapping exists.
        raw_by_display = {display: raw for raw, display in COLUMN_DISPLAY_NAMES.items()}
        raw_column = raw_by_display.get(sort_by, sort_by)

        if raw_column in working_df.columns:
            working_df = working_df.sort_values(
                by=raw_column,
                ascending=(sort_direction == "Ascending"),
            )

    return format_dataframe(working_df, show_percentage, selected_groups, compact_view)

def show_output_box(message):
    """Reveal the status textbox, populated with *message*."""
    return gr.update(visible=True, value=message)

def submit_cif_files(model_name, problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
    """Submit structures to the leaderboard.

    Validates the form inputs, then uploads the structure file plus a JSON
    metadata record to the submissions dataset repo under a unique
    submission ID.

    Args:
        model_name: Display name for the model/method being submitted.
        problem_type: One of PROBLEM_TYPES.
        cif_files: Path to the uploaded file (CSV, pkl, or ZIP of CIFs).
        relaxed: Whether the submitted structures are already relaxed.
        profile: Hugging Face OAuth profile injected by Gradio from the type
            hint; None when the user is not logged in.

    Returns:
        Tuple of (status message, submission ID); the ID is None on failure.
    """
    # Validate all inputs before touching the Hub.
    if not model_name or not model_name.strip():
        return "Error: Please provide a model name.", None

    if not problem_type:
        return "Error: Please select a problem type.", None

    if not cif_files:
        return "Error: Please upload a file.", None

    if not profile:
        return "Error: Please log in to submit.", None

    try:
        # Imported lazily so validation-only calls never need huggingface_hub.
        from huggingface_hub import upload_file

        username = profile.username
        timestamp = datetime.now().isoformat()

        # Submission metadata stored alongside the uploaded file.
        submission_data = {
            "username": username,
            "model_name": model_name.strip(),
            "problem_type": problem_type,
            "relaxed": relaxed,
            "timestamp": timestamp,
            "file_name": Path(cif_files).name
        }

        # Unique, path-safe submission ID (':' is invalid in repo paths).
        submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}"

        # Upload the structures file itself.
        file_path = Path(cif_files)
        uploaded_file_path = f"submissions/{submission_id}/{file_path.name}"

        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=uploaded_file_path,
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset"
        )

        # Upload the metadata straight from memory (upload_file accepts raw
        # bytes). This replaces the previous NamedTemporaryFile round-trip,
        # which leaked the temp file whenever this upload raised before the
        # os.unlink cleanup ran.
        metadata_path = f"submissions/{submission_id}/metadata.json"
        upload_file(
            path_or_fileobj=json.dumps(submission_data, indent=2).encode("utf-8"),
            path_in_repo=metadata_path,
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset"
        )

        return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id

    except Exception as e:
        return f"Error during submission: {str(e)}", None

def generate_metric_legend_html():
    """Build the HTML legend table mapping metric-group colors to metrics.

    One row per entry in METRIC_GROUP_COLORS: a color swatch, the group name
    (direction arrows stripped), the metrics it contains, and whether higher
    or lower values are better.
    """
    metric_details = {
        'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
        'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
        'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
        'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
        'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
        'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
        'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
        'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
    }

    # Shared inline styles for body and header cells.
    td = 'border: 1px solid #ddd; padding: 8px;'
    th = f'{td} text-align: left;'

    # Assemble the markup as parts and join once at the end.
    parts = ['<table style="width: 100%; border-collapse: collapse;">', '<thead><tr>']
    parts.extend(f'<th style="{th}">{label}</th>' for label in ('Color', 'Group', 'Metrics', 'Direction'))
    parts.append('</tr></thead><tbody>')

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        group_name = group.replace('↑', '').replace('↓', '').strip()
        swatch = f'<div style="width: 30px; height: 20px; background-color: {color}; border: 1px solid #999;"></div>'

        parts.append('<tr>')
        parts.append(f'<td style="{td}">{swatch}</td>')
        parts.append(f'<td style="{td}"><strong>{group_name}</strong></td>')
        parts.append(f'<td style="{td}">{metrics}</td>')
        parts.append(f'<td style="{td}">{direction}</td>')
        parts.append('</tr>')

    parts.append('</tbody></table>')
    return ''.join(parts)

def gradio_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks application.

    Two tabs: a leaderboard whose display options re-render from a cached
    dataframe, and a submission form that uploads structure files for
    evaluation.

    Returns:
        gr.Blocks: The assembled (un-launched) application.
    """
    with gr.Blocks() as demo:
        # Header / intro markdown shown above the tabs.
        # (Fixed: the intro previously had a dangling '**' after
        # "evaluation metrics" with no opening bold marker.)
        gr.Markdown("""
# πŸ”¬ LeMat-GenBench: A Unified Benchmark for Generative Models of Crystalline Materials

Generative machine learning models hold great promise for accelerating materials discovery, particularly through the inverse design of inorganic crystals, enabling an unprecedented exploration of chemical space. Yet, the lack of standardized evaluation frameworks makes it difficult to evaluate, compare and further develop these ML models meaningfully.

**LeMat-GenBench** introduces a unified benchmark for generative models of crystalline materials, with **standardized evaluation metrics** for meaningful model comparison, diverse tasks, and this leaderboard to encourage and track community progress.

πŸ“„ **Paper**: [arXiv preprint](https://arxiv.org/abs/XXXX.XXXXX) | πŸ’» **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | πŸ“§ **Contact**: siddharth.betala-ext [at] entalpic.ai, alexandre.duval [at] entalpic.ai
""")
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("πŸš€ Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")

                # Display options.
                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics"
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures"
                        )
                    with gr.Column(scale=1):
                        # Choices are display names; update_leaderboard maps
                        # them back to raw column names before sorting.
                        sort_choices = ["None"] + [COLUMN_DISPLAY_NAMES.get(col, col) for col in COLUMN_DISPLAY_NAMES.keys()]
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by"
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction"
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display"
                        )

                # Metric legend with color coding.
                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())

                try:
                    # Load the results once and keep them in gr.State so that
                    # toggling display options never re-downloads the dataset.
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)

                    formatted_df = format_dataframe(initial_df, show_percentage=True, selected_groups=list(METRIC_GROUPS.keys()), compact_view=True)

                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        column_widths=["180px"] + ["160px"] * (len(formatted_df.columns) - 1) if len(formatted_df.columns) > 0 else None,
                        show_fullscreen_button=True
                    )

                    # Every display control triggers the same refresh from the
                    # cached frame; register the identical handler in a loop
                    # instead of five copy-pasted .change(...) blocks.
                    refresh_inputs = [show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction]
                    for control in (show_percentage, selected_groups, compact_view, sort_by, sort_direction):
                        control.change(
                            fn=update_leaderboard,
                            inputs=refresh_inputs,
                            outputs=leaderboard_table
                        )

                except Exception as e:
                    gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")

                gr.Markdown("""
**Symbol Legend:**
- ⚑ Structures were already relaxed
- β˜… Contributes to LeMat-Bulk reference dataset (in-distribution)
- β—† Out-of-distribution relative to LeMat-Bulk reference dataset

Verified submissions mean the results came from a model submission rather than a CIF submission.
""")

            with gr.TabItem("βœ‰οΈ Submit", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown(
                    """
                # Materials Submission
                Upload a CSV, pkl, or a ZIP of CIFs with your structures.
                """
                )
                # Holds the submission ID returned by submit_cif_files.
                filename = gr.State(value=None)

                gr.LoginButton()

                with gr.Row():
                    with gr.Column():
                        model_name_input = gr.Textbox(
                            label="Model Name",
                            placeholder="Enter your model name",
                            info="Provide a name for your model/method"
                        )
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are already relaxed",
                            info="Check this box if your submitted structures have already been relaxed"
                        )

                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)
                # Help message.
                gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")

                # The OAuth profile argument of submit_cif_files is injected by
                # Gradio from its gr.OAuthProfile type hint; it is deliberately
                # absent from `inputs` here.
                submit_btn.click(
                    submit_cif_files,
                    inputs=[model_name_input, problem_type, cif_file, relaxed],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )

    return demo


if __name__ == "__main__":
    # Script entry point: build the app and serve it (blocks until stopped).
    gradio_interface().launch()