Spaces:
Running
Running
| from pathlib import Path | |
| from typing import Literal | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.colors as pcolors | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| from plotly.subplots import make_subplots | |
| from mlip_arena.models import REGISTRY | |
st.title("Stability")

# Benchmark results are read from <two levels above this file>/benchmarks/stability.
DATA_DIR = Path(__file__).parents[2] / "benchmarks" / "stability"

st.markdown("### Methods")
container = st.container(border=True)

# Offer only the registry entries whose heating results exist on disk.
# The parquet path is <DATA_DIR>/<family, lower-cased>/<model>-heating.parquet.
# NOTE(review): assumes REGISTRY keys are already strings, so that `metadata`
# is the same entry the original looked up via REGISTRY[str(model)].
valid_models = [
    model
    for model, metadata in REGISTRY.items()
    if (DATA_DIR / metadata["family"].lower() / f"{model}-heating.parquet").exists()
]

# Models pre-selected when the page first loads.
preferred_defaults = [
    "MACE-MP(M)",
    "CHGNet",
    "SevenNet",
    "ORBv2",
    "eqV2(OMat)",
    "M3GNet",
    "MatterSim",
    "MACE-MPA",
]

# Streamlit rejects a multiselect default that is not among the options, so
# drop any preferred model whose results are not available.
default_models = [name for name in preferred_defaults if name in valid_models]

models = container.multiselect("MLIPs", valid_models, default_models)
st.markdown("### Settings")
vis = st.container(border=True)

# Collect every public list-valued attribute of plotly.colors.qualitative as a
# selectable colour sequence. Underscore-prefixed names (module internals such
# as ``__all__``) are skipped so they never show up as a palette choice.
color_palettes = {}
for attr in dir(pcolors.qualitative):
    if attr.startswith("_"):
        continue
    palette = getattr(pcolors.qualitative, attr)
    if isinstance(palette, list):
        color_palettes[attr] = palette

palette_names = list(color_palettes)

# Select the default palette by name rather than by position, so the choice
# does not shift (or go out of range) when Plotly adds or removes palettes.
# NOTE(review): the previous positional default (index 22 of the alphabetical
# listing) appears to correspond to "Plotly" - confirm this is the intended
# default.
default_palette = "Plotly"
default_index = (
    palette_names.index(default_palette) if default_palette in palette_names else 0
)

palette_name = vis.selectbox(
    "Color sequence", options=palette_names, index=default_index
)
color_sequence = color_palettes[palette_name]

# Nothing to plot until at least one model is selected.
if not models:
    st.stop()
def get_data(model_list, run_type: Literal["heating", "compression"]) -> pd.DataFrame:
    """Load and concatenate the parquet results of the selected models.

    For each model the file
    ``DATA_DIR / <family, lower-cased> / "<model>-<run_type>.parquet"`` is read
    if it exists; models without such a file are silently skipped. Every loaded
    frame gets a ``method`` column holding the model name.

    Args:
        model_list: Iterable of model names; each must be a key of ``REGISTRY``.
        run_type: Which benchmark to load, ``"heating"`` or ``"compression"``.

    Returns:
        One frame with the rows of all loaded files (index reset). When no file
        was found, an empty frame that still carries the ``method`` column, so
        callers can access ``df["method"]`` without a ``KeyError``.
    """
    frames = []
    for m in model_list:
        fpath = (
            DATA_DIR / REGISTRY[str(m)]["family"].lower() / f"{m}-{run_type}.parquet"
        )
        if not fpath.exists():
            continue
        df_local = pd.read_parquet(fpath)
        df_local["method"] = str(m)
        frames.append(df_local)

    if not frames:
        # Keep the one column this function itself guarantees.
        return pd.DataFrame(columns=["method"])
    return pd.concat(frames, ignore_index=True)
df_nvt = get_data(models, run_type="heating")
df_npt = get_data(models, run_type="compression")

# Map model -> colour, cycling through the chosen palette.
# The mapping is built from the selected models rather than from the rows of
# the heating frame, so it still works when that frame is empty and also
# covers models that only have compression results. get_data() labels rows
# with str(model), hence the str() key.
method_color_mapping = {
    str(method): color_sequence[i % len(color_sequence)]
    for i, method in enumerate(models)
}
def prepare_scatter_df(df_in: pd.DataFrame, max_points: int = 20000) -> pd.DataFrame:
    """Return a plot-ready copy of *df_in* with a ``_marker_size`` column.

    Rows lacking ``natoms`` or ``steps_per_second`` are dropped. If more than
    *max_points* rows remain, a reproducible random sample (``random_state=1``)
    of that size is kept.

    Marker sizes:
        * With a usable ``total_steps`` column, values are min-max scaled to
          the range 5..45 (missing entries are filled with the median first;
          a constant column maps to 5).
        * Without ``total_steps``, or when it holds no values at all, every
          marker gets the constant size 8.

    The ``_marker_size`` column is always present in the result, including
    when the result is empty or when *df_in* lacks the required columns, so
    the frame can be passed to ``px.scatter(size="_marker_size")`` directly.

    Args:
        df_in: Benchmark rows; not modified.
        max_points: Upper bound on the number of returned rows.

    Returns:
        A new DataFrame.
    """
    required = ["natoms", "steps_per_second"]
    default_size = 8

    # A frame without the required columns has nothing plottable; adding them
    # as all-NaN columns lets the dropna below empty it instead of raising.
    missing = [col for col in required if col not in df_in.columns]
    source = df_in.reindex(columns=[*df_in.columns, *missing]) if missing else df_in
    dfp = source.dropna(subset=required).copy()

    # Downsample if there are too many points.
    if len(dfp) > max_points:
        dfp = dfp.sample(max_points, random_state=1)

    if dfp.empty or "total_steps" not in dfp.columns:
        dfp["_marker_size"] = default_size
        return dfp

    steps = dfp["total_steps"].astype(float)
    if steps.isna().all():
        # The median would be NaN and every size would become NaN.
        dfp["_marker_size"] = default_size
        return dfp

    steps = steps.fillna(steps.median())
    lowest = steps.min()
    span = steps.max() - lowest
    # Guard the division when all values are equal.
    scaled = (steps - lowest) / (span if span != 0 else 1.0)
    dfp["_marker_size"] = (scaled * 40) + 5
    return dfp
def compute_power_law_fits(df_in: pd.DataFrame) -> dict:
    """Fit the power law ``steps/s ~ a * N^(-n)`` for each method.

    The fit is a straight-line least-squares fit of ``log(steps_per_second)``
    against ``log(natoms)``, computed separately for every value of the
    ``method`` column. Only rows with both values present and strictly
    positive take part.

    A method is skipped when it has fewer than three usable rows, when all its
    usable rows share one ``natoms`` value (the slope is then undetermined), or
    when the numerical fit fails.

    Args:
        df_in: Frame with ``method``, ``natoms`` and ``steps_per_second``
            columns. An empty frame, or one missing any of these columns,
            yields an empty result.

    Returns:
        ``{method: (a, n)}`` with ``a = exp(intercept)`` and ``n = -slope``.
    """
    required = {"method", "natoms", "steps_per_second"}
    if df_in.empty or not required.issubset(df_in.columns):
        return {}

    fits = {}
    for name, grp in df_in.groupby("method"):
        pts = grp.dropna(subset=["natoms", "steps_per_second"])
        # Logarithms need strictly positive inputs.
        pts = pts[(pts["natoms"] > 0) & (pts["steps_per_second"] > 0)]
        if len(pts) < 3 or pts["natoms"].nunique() < 2:
            continue
        try:
            log_n = np.log(pts["natoms"].astype(float))
            log_rate = np.log(pts["steps_per_second"].astype(float))
            slope, intercept = np.polyfit(log_n, log_rate, 1)
        except (TypeError, ValueError, np.linalg.LinAlgError):
            # Best effort: one method that cannot be fitted must not prevent
            # the others from being plotted.
            continue
        fits[name] = (float(np.exp(intercept)), float(-slope))  # (a, n)
    return fits
def build_speed_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool
) -> go.Figure:
    """Build a log-log plot of inference speed against number of atoms.

    One fitted power-law line is drawn per method (see
    ``compute_power_law_fits``). When *show_scatter* is true the individual
    runs are drawn as well, with marker sizes taken from
    ``prepare_scatter_df``.

    Args:
        df_in: Benchmark rows with ``method``, ``natoms`` and
            ``steps_per_second`` columns. An empty frame gives a figure with
            axes but no traces.
        color_map: Method name -> colour. Lines of methods without an entry
            are drawn in black.
        show_scatter: Whether to draw the individual runs. The legend entries
            come from the scatter traces when they are shown and from the fit
            lines otherwise.

    Returns:
        The assembled figure.
    """
    fig = go.Figure()

    if not df_in.empty:
        if show_scatter:
            dfp = prepare_scatter_df(df_in)
            # px.scatter needs the size column, which an empty prepared frame
            # may not have; there is nothing to draw in that case anyway.
            if not dfp.empty:
                scatter_fig = px.scatter(
                    dfp,
                    x="natoms",
                    y="steps_per_second",
                    color="method",
                    size="_marker_size",
                    hover_data=[
                        c for c in ["material_id", "formula"] if c in dfp.columns
                    ],
                    color_discrete_map=color_map,
                    log_x=True,
                    log_y=True,
                    render_mode="webgl",
                    labels={
                        "steps_per_second": "Steps per second",
                        "natoms": "Number of atoms",
                    },
                )
                for trace in scatter_fig.data:
                    fig.add_trace(trace)

        # Overlay the fitted power laws.
        for method, (a, n) in compute_power_law_fits(df_in).items():
            natoms = df_in.loc[df_in["method"] == method, "natoms"].dropna()
            # The fit only used positive atom counts; restricting the drawn
            # range the same way keeps log10 finite.
            natoms = natoms[natoms > 0].astype(float)
            if natoms.empty:
                continue
            xs = np.logspace(np.log10(natoms.min()), np.log10(natoms.max()), 200)
            ys = a * xs ** (-n)
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=ys,
                    mode="lines",
                    line=dict(color=color_map.get(method, "black"), width=2),
                    showlegend=not show_scatter,
                    name=f"{method}",
                )
            )

    fig.update_layout(
        height=520,
        title="Inference speed (steps/s)",
        xaxis=dict(type="log", title="Number of atoms"),
        yaxis=dict(type="log", title="Steps per second"),
    )
    return fig
def _build_stability_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int
) -> go.Figure:
    """Build the two-panel figure shared by the NVT and NPT sections.

    Left panel: for each method, the percentage of runs (relative to
    *total_runs*) whose ``normalized_final_step`` lies beyond a given point of
    the normalized time axis, using 49 equal-width bins on [0, 1] and one row
    per distinct ``formula``.
    Right panel: the traces produced by ``build_speed_figure``.
    """
    if total_runs <= 0:
        raise ValueError("total_runs must be a positive number")

    fig = make_subplots(
        rows=1,
        cols=2,
        column_widths=[0.4, 0.6],
        subplot_titles=("Valid runs", "Inference speed: steps/s vs N"),
    )

    # An empty frame has no traces to add; the axes are still laid out below.
    if not df_in.empty:
        # Right panel: speed scaling.
        speed_fig = build_speed_figure(df_in, color_map, show_scatter)
        for trace in speed_fig.data:
            fig.add_trace(trace, row=1, col=2)

        # Left panel: runs still valid as a function of normalized time.
        bin_edges = np.linspace(0, 1, 50)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        for method, df_model in df_in.groupby("method"):
            unique_runs = df_model.drop_duplicates(["formula"])
            hist, _ = np.histogram(
                unique_runs["normalized_final_step"], bins=bin_edges
            )
            cumulative = np.cumsum(hist)
            # Runs that ended after each bin, i.e. that were still valid there.
            still_valid = cumulative[-1] - cumulative[:-1]
            fig.add_trace(
                go.Scatter(
                    x=bin_centers[:-1],
                    y=still_valid / total_runs * 100,
                    mode="lines",
                    line=dict(color=color_map.get(method)),
                    name=str(method),
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

    fig.update_xaxes(title_text="Normalized time", row=1, col=1, range=[0, 1])
    fig.update_yaxes(title_text="Valid runs (%)", row=1, col=1)
    fig.update_xaxes(type="log", row=1, col=2, title_text="Number of atoms")
    fig.update_yaxes(type="log", row=1, col=2, title_text="Steps per second")
    fig.update_layout(height=520, width=1000)
    return fig


def build_nvt_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int = 120
) -> go.Figure:
    """Build subplot: NVT valid runs (cumulative) + speed scaling plot.

    Args:
        df_in: Heating results with ``method``, ``formula``,
            ``normalized_final_step``, ``natoms`` and ``steps_per_second``
            columns.
        color_map: Method name -> colour.
        show_scatter: Whether the speed panel shows the individual runs.
        total_runs: Denominator of the "Valid runs (%)" axis. Defaults to 120,
            the value previously hard-coded here.
            NOTE(review): presumably the number of structures in the heating
            benchmark - confirm.

    Returns:
        The two-panel figure.

    Raises:
        ValueError: If *total_runs* is not positive.
    """
    return _build_stability_figure(df_in, color_map, show_scatter, total_runs)


def build_npt_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int = 80
) -> go.Figure:
    """Build subplot: NPT valid runs (cumulative) + speed scaling plot.

    Args:
        df_in: Compression results with ``method``, ``formula``,
            ``normalized_final_step``, ``natoms`` and ``steps_per_second``
            columns.
        color_map: Method name -> colour.
        show_scatter: Whether the speed panel shows the individual runs.
        total_runs: Denominator of the "Valid runs (%)" axis. Defaults to 80,
            the value previously hard-coded here.
            NOTE(review): presumably the number of structures in the
            compression benchmark - confirm.

    Returns:
        The two-panel figure.

    Raises:
        ValueError: If *total_runs* is not positive.
    """
    return _build_stability_figure(df_in, color_map, show_scatter, total_runs)
# Page body: one section per benchmark. Each section is rendered on its own,
# so a benchmark without results for the selected models shows a notice
# instead of failing while its figure is built.
NO_DATA_MESSAGE = "No data available to display for selected models."

HEATING_INTRO = """
## Heating
Isochoric-isothermal (NVT) MD simulations on RM24 structures, with temperature ramp from 300K to 3000K over 10 ps.
"""

COMPRESSION_INTRO = """
## Compression
Isothermal-isobaric (NPT) MD simulations on RM24 structures, with pressure ramp from 0 GPa to 500 GPa and temperature ramp from 300K to 3000K over 10 ps.
"""

if df_nvt.empty and df_npt.empty:
    st.info(NO_DATA_MESSAGE)
else:
    # (intro markdown, data, figure builder, session key of the scatter toggle)
    sections = (
        (HEATING_INTRO, df_nvt, build_nvt_figure, "show_scatter_nvt"),
        (COMPRESSION_INTRO, df_npt, build_npt_figure, "show_scatter_npt"),
    )
    for intro_md, df_section, build_figure, toggle_key in sections:
        st.markdown(intro_md)
        if df_section.empty:
            st.info(NO_DATA_MESSAGE)
            continue
        show_scatter = st.toggle("Show scatter points", key=toggle_key, value=True)
        st.plotly_chart(
            build_figure(df_section, method_color_mapping, show_scatter),
            use_container_width=True,
        )