Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| import matplotlib.pyplot as plt | |
# Known training series as (series key, legend label, line colour) triples.
# List order is the order in which series are considered for plotting.
SERIES_ORDER: list[tuple[str, str, str]] = [
    # Stage-specific series: one per role and training stage.
    ("center_sft", "center SFT", "#2563eb"),
    ("center_grpo", "center GRPO", "#7c3aed"),
    ("warehouse_sft", "warehouse SFT", "#0f766e"),
    ("warehouse_grpo", "warehouse GRPO", "#047857"),
    # Role-level series without a stage suffix.
    ("center", "center", "#2563eb"),
    ("warehouse", "warehouse", "#0f766e"),
]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the plotting script.

    Returns:
        Namespace with ``input`` and ``output_dir``, both converted to
        ``Path`` and defaulting to locations under ``results/``.
    """
    cli = argparse.ArgumentParser(
        description="Create static SupplyMind training/eval plots.",
    )
    # Both options are plain paths; only the flag and the default differ.
    path_options = {
        "--input": Path("results/training_dashboard.json"),
        "--output-dir": Path("results/plots"),
    }
    for flag, default in path_options.items():
        cli.add_argument(flag, type=Path, default=default)
    return cli.parse_args()
def load_json(path: Path) -> dict[str, Any]:
    """Read *path* as UTF-8 text and return the decoded JSON document.

    The return annotation assumes the top-level JSON value is an object;
    that is not checked here.
    """
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def value(row: dict[str, Any], *keys: str) -> float | None:
    """Return the first usable numeric value stored under *keys* in *row*.

    Keys are tried in the order given. A key is skipped when it is absent,
    maps to ``None``, or holds something ``float()`` rejects.

    Returns:
        The converted float, or ``None`` when no key yields a number.
    """
    candidates = (row.get(name) for name in keys)
    for candidate in candidates:
        if candidate is not None:
            try:
                return float(candidate)
            except (TypeError, ValueError):
                # Not numeric: fall through to the next key.
                pass
    return None
def active_series(data: dict[str, Any]) -> list[tuple[str, str, str]]:
    """Select the ``SERIES_ORDER`` entries to plot for *data*.

    Only entries whose key appears in ``data["training_series"]`` are
    considered. If any of the four stage-specific keys is present, the
    result is narrowed to entries ending in ``_sft`` or ``_grpo``;
    otherwise every present entry is returned.

    Returns:
        ``(key, label, color)`` tuples in ``SERIES_ORDER`` order.
    """
    logged = data.get("training_series", {})
    present = [entry for entry in SERIES_ORDER if entry[0] in logged]

    stage_keys = ("center_sft", "center_grpo", "warehouse_sft", "warehouse_grpo")
    if not any(name in logged for name in stage_keys):
        return present
    return [entry for entry in present if entry[0].endswith(("_sft", "_grpo"))]
def line_plot(data: dict[str, Any], y_keys: tuple[str, ...], title: str, ylabel: str, output: Path) -> None:
    """Draw one metric against training step for every active series.

    For each series returned by ``active_series(data)`` the rows under
    ``training_series[key]["steps"]`` are scanned. A row contributes a point
    when one of *y_keys* yields a number; rows without the metric are
    skipped. The x coordinate is the row's ``step`` (or ``global_step``),
    falling back to the 1-based row index only when neither is usable.

    Args:
        data: Dashboard document containing ``training_series``.
        y_keys: Candidate row keys for the metric, tried in order.
        title: Figure title.
        ylabel: Y-axis label.
        output: Destination image path.
    """
    series = data.get("training_series", {})
    plt.figure(figsize=(10, 5.2))
    try:
        plotted = False
        for key, label, color in active_series(data):
            rows = series.get(key, {}).get("steps", [])
            xs: list[float] = []
            ys: list[float] = []
            for idx, row in enumerate(rows, start=1):
                y = value(row, *y_keys)
                if y is None:
                    continue
                step = value(row, "step", "global_step")
                # Compare against None explicitly: `step or idx` would treat a
                # logged step of 0 as missing and silently move it to idx.
                xs.append(idx if step is None else step)
                ys.append(y)
            if xs:
                plt.plot(xs, ys, label=label, color=color, linewidth=2)
                plotted = True
        plt.title(title)
        plt.xlabel("training step")
        plt.ylabel(ylabel)
        plt.grid(alpha=0.25)
        if plotted:
            plt.legend()
        else:
            # Still emit an image so the output set is complete.
            plt.text(0.5, 0.5, "No series available", ha="center", va="center", transform=plt.gca().transAxes)
        plt.tight_layout()
        plt.savefig(output, dpi=160)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close()
def invalid_plot(data: dict[str, Any], output: Path) -> None:
    """Draw grouped bars of invalid payload / action totals per series.

    For every active series that has a non-empty ``reward_batches`` list,
    the ``invalid_payloads`` and ``invalid_actions`` values are summed over
    all batches (missing values count as 0). When no series has batches, a
    placeholder message is drawn instead of bars.

    Args:
        data: Dashboard document containing ``training_series``.
        output: Destination image path.
    """
    logged = data.get("training_series", {})

    def total(batches: list[dict[str, Any]], field: str) -> float:
        # Sum one counter over all logged batches of a series.
        return sum(value(batch, field) or 0 for batch in batches)

    # One (label, payload total, action total) entry per series with batches.
    totals: list[tuple[str, float, float]] = []
    for key, label, _color in active_series(data):
        batches = logged.get(key, {}).get("reward_batches", [])
        if batches:
            totals.append((label, total(batches, "invalid_payloads"), total(batches, "invalid_actions")))

    plt.figure(figsize=(10, 5.2))
    if not totals:
        plt.text(0.5, 0.5, "No invalid-action diagnostics available", ha="center", va="center", transform=plt.gca().transAxes)
    else:
        names = [entry[0] for entry in totals]
        payload_counts = [entry[1] for entry in totals]
        action_counts = [entry[2] for entry in totals]
        positions = range(len(names))
        # The two bars of a group sit side by side around the tick position.
        plt.bar([pos - 0.18 for pos in positions], payload_counts, width=0.36, label="invalid payloads", color="#c2410c")
        plt.bar([pos + 0.18 for pos in positions], action_counts, width=0.36, label="invalid env actions", color="#b7791f")
        plt.xticks(list(positions), names, rotation=20, ha="right")
        plt.legend()
    plt.title("Invalid Payloads / Actions")
    plt.ylabel("count across logged reward batches")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def heldout_plot(data: dict[str, Any], output: Path) -> None:
    """Draw held-out scores for the base, SFT and GRPO variants.

    The x axis has four groups: a global score and a role score for each of
    the ``center`` and ``warehouse`` roles. For each variant, the first entry
    of ``data["comparisons"]`` whose ``role`` and ``label`` match supplies the
    score; groups without a usable score get no bar for that variant. When
    nothing can be drawn, a placeholder message is shown instead.

    Args:
        data: Dashboard document containing ``comparisons``.
        output: Destination image path.
    """
    comparisons = data.get("comparisons", [])

    role_score_keys = {
        "center": "mean_center_role_score",
        "warehouse": "mean_warehouse_role_score",
    }
    # (role, metric name shown on the tick, row key holding the score)
    groups: list[tuple[str, str, str]] = []
    for role in ("center", "warehouse"):
        groups.append((role, "global", "mean_global_score"))
        groups.append((role, "role", role_score_keys[role]))

    def find_row(role: str, variant: str) -> dict[str, Any] | None:
        # First comparison entry for this role/variant pair, if any.
        for item in comparisons:
            if item.get("role") == role and item.get("label") == variant:
                return item
        return None

    plt.figure(figsize=(11, 5.5))
    group_xs = list(range(len(groups)))
    width = 0.22
    # Horizontal offset, comparison label and bar colour for each variant.
    bars = [
        (-width, "base", "#64748b"),
        (0, "sft", "#2563eb"),
        (width, "grpo", "#7c3aed"),
    ]

    any_rows = False
    for offset, variant, color in bars:
        positions = []
        heights = []
        for x, (role, _metric, key) in enumerate(groups):
            row = find_row(role, variant)
            score = value(row, key) if row else None
            if score is not None:
                positions.append(x + offset)
                heights.append(score)
        if heights:
            any_rows = True
            plt.bar(positions, heights, width=width, label=variant.upper(), color=color)

    if not any_rows:
        plt.text(0.5, 0.5, "No held-out comparisons available", ha="center", va="center", transform=plt.gca().transAxes)
    else:
        plt.xticks(group_xs, [f"{role}\n{metric}" for role, metric, _key in groups])
        plt.ylim(0, 1)
        plt.legend()
    plt.title("Held-out Scores: Base vs SFT vs GRPO")
    plt.ylabel("normalized score")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def main() -> None:
    """Load the dashboard JSON and write every plot into the output directory.

    Produces four line plots (loss, reward, clipped ratio, completion
    length), the invalid payload/action bar chart and the held-out
    comparison chart, then prints the output directory.
    """
    args = parse_args()
    data = load_json(args.input)
    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # (candidate metric keys, title, y-axis label, file name) per line plot.
    line_charts = [
        (("loss",), "Loss Over Step", "loss", "loss.png"),
        (("reward", "rewards/reward_completions/mean"), "Reward Over Step", "reward", "reward.png"),
        (("completions/clipped_ratio", "clipped_ratio"), "Clipped Ratio Over Step", "clipped ratio", "clipped_ratio.png"),
        (
            ("completions/mean_length", "completion_length", "mean_completion_length"),
            "Completion Length Over Step",
            "tokens",
            "completion_length.png",
        ),
    ]
    for y_keys, title, ylabel, filename in line_charts:
        line_plot(data, y_keys, title, ylabel, out_dir / filename)

    invalid_plot(data, out_dir / "invalids.png")
    heldout_plot(data, out_dir / "heldout_comparison.png")
    print(f"Wrote plots to {out_dir}")


if __name__ == "__main__":
    main()