# supplymind/scripts/plot_training_results.py
# Rishav — Tighten role training scaffold (a2144da)
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
# Known training series in plotting order, as (dashboard key, legend label, color).
# The role-specific SFT/GRPO keys come first. The plain "center"/"warehouse" keys
# are only plotted when no "_sft"/"_grpo" key is present (see active_series()).
# Note: "center" reuses the center-SFT color and "warehouse" the warehouse-SFT color.
SERIES_ORDER: list[tuple[str, str, str]] = [
    ("center_sft", "center SFT", "#2563eb"),
    ("center_grpo", "center GRPO", "#7c3aed"),
    ("warehouse_sft", "warehouse SFT", "#0f766e"),
    ("warehouse_grpo", "warehouse GRPO", "#047857"),
    ("center", "center", "#2563eb"),
    ("warehouse", "warehouse", "#0f766e"),
]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Create static SupplyMind training/eval plots.")
parser.add_argument("--input", type=Path, default=Path("results/training_dashboard.json"))
parser.add_argument("--output-dir", type=Path, default=Path("results/plots"))
return parser.parse_args()
def load_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON document stored at *path* and return it."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def value(row: dict[str, Any], *keys: str) -> float | None:
for key in keys:
raw = row.get(key)
if raw is None:
continue
try:
return float(raw)
except (TypeError, ValueError):
continue
return None
def active_series(data: dict[str, Any]) -> list[tuple[str, str, str]]:
    """Select the SERIES_ORDER entries that have data under ``data["training_series"]``.

    If any role-specific series (center/warehouse SFT or GRPO) is present,
    only ``*_sft`` / ``*_grpo`` entries are kept and the plain ``center`` /
    ``warehouse`` entries are dropped. Otherwise every present entry is
    kept. The result follows SERIES_ORDER ordering.
    """
    present = data.get("training_series", {})
    role_specific_keys = ("center_sft", "center_grpo", "warehouse_sft", "warehouse_grpo")
    has_role_specific = any(name in present for name in role_specific_keys)

    selected: list[tuple[str, str, str]] = []
    for entry in SERIES_ORDER:
        name = entry[0]
        if name not in present:
            continue
        if has_role_specific and not name.endswith(("_sft", "_grpo")):
            continue
        selected.append(entry)
    return selected
def line_plot(data: dict[str, Any], y_keys: tuple[str, ...], title: str, ylabel: str, output: Path) -> None:
    """Draw one line per active training series for a single metric and save it.

    Args:
        data: Dashboard document. For each key returned by active_series(),
            rows are read from ``data["training_series"][key]["steps"]``.
        y_keys: Candidate metric names, tried in order for each row (the first
            numeric one wins). Rows with none of them are skipped.
        title: Figure title.
        ylabel: Y-axis label.
        output: Path of the PNG to write.

    The x coordinate is the row's ``step`` (or ``global_step``). Rows without
    either fall back to their 1-based position in the ``steps`` list. If no
    series yields any point, a placeholder message replaces the legend.
    """
    series = data.get("training_series", {})
    fig = plt.figure(figsize=(10, 5.2))
    try:
        plotted = False
        for key, label, color in active_series(data):
            rows = series.get(key, {}).get("steps", [])
            xs: list[float] = []
            ys: list[float] = []
            for idx, row in enumerate(rows, start=1):
                y = value(row, *y_keys)
                if y is None:
                    continue
                step = value(row, "step", "global_step")
                # Explicit None check: a logged step of 0 is a real x position and
                # must not be replaced by the positional fallback (`or idx` did that).
                xs.append(step if step is not None else idx)
                ys.append(y)
            if xs:
                plt.plot(xs, ys, label=label, color=color, linewidth=2)
                plotted = True
        plt.title(title)
        plt.xlabel("training step")
        plt.ylabel(ylabel)
        plt.grid(alpha=0.25)
        if plotted:
            plt.legend()
        else:
            plt.text(0.5, 0.5, "No series available", ha="center", va="center", transform=plt.gca().transAxes)
        plt.tight_layout()
        plt.savefig(output, dpi=160)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def invalid_plot(data: dict[str, Any], output: Path) -> None:
    """Draw grouped bars of invalid payloads and invalid env actions per series.

    For every active series that has a non-empty ``reward_batches`` list, the
    ``invalid_payloads`` and ``invalid_actions`` fields are summed across its
    batches (missing or non-numeric fields count as 0). Series without reward
    batches are left out. If none remain, a placeholder message is drawn.
    The figure is saved to *output*.
    """
    training = data.get("training_series", {})
    totals: list[tuple[str, float, float]] = []
    for key, label, _ in active_series(data):
        batches = training.get(key, {}).get("reward_batches", [])
        if batches:
            payload_total = sum(value(batch, "invalid_payloads") or 0 for batch in batches)
            action_total = sum(value(batch, "invalid_actions") or 0 for batch in batches)
            totals.append((label, payload_total, action_total))

    plt.figure(figsize=(10, 5.2))
    if totals:
        labels, payloads, actions = (list(column) for column in zip(*totals))
        positions = list(range(len(labels)))
        # Offset the two bar groups by half a bar width on either side of each tick.
        plt.bar([pos - 0.18 for pos in positions], payloads, width=0.36, label="invalid payloads", color="#c2410c")
        plt.bar([pos + 0.18 for pos in positions], actions, width=0.36, label="invalid env actions", color="#b7791f")
        plt.xticks(positions, labels, rotation=20, ha="right")
        plt.legend()
    else:
        plt.text(0.5, 0.5, "No invalid-action diagnostics available", ha="center", va="center", transform=plt.gca().transAxes)
    plt.title("Invalid Payloads / Actions")
    plt.ylabel("count across logged reward batches")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def heldout_plot(data: dict[str, Any], output: Path) -> None:
    """Draw held-out base/SFT/GRPO scores side by side for each role and save them.

    Each role ("center", "warehouse") contributes two bar groups: its
    ``mean_global_score`` and its role-specific score
    (``mean_center_role_score`` / ``mean_warehouse_role_score``). Scores come
    from the first entry of ``data["comparisons"]`` whose ``role`` and
    ``label`` match. Missing or non-numeric scores are left out. The y-axis
    is fixed to [0, 1] when anything is drawn. Otherwise a placeholder
    message is shown.
    """
    comparisons = data.get("comparisons", [])
    role_score_keys = {"center": "mean_center_role_score", "warehouse": "mean_warehouse_role_score"}
    groups = [
        (role, metric, key)
        for role in ("center", "warehouse")
        for metric, key in (("global", "mean_global_score"), ("role", role_score_keys[role]))
    ]
    variants = (("base", "#64748b"), ("sft", "#2563eb"), ("grpo", "#7c3aed"))
    width = 0.22
    offsets = (-width, 0, width)
    positions = list(range(len(groups)))

    def find_row(role: str, label: str) -> dict[str, Any] | None:
        # First matching comparison entry wins, as in a linear scan.
        return next(
            (item for item in comparisons if item.get("role") == role and item.get("label") == label),
            None,
        )

    plt.figure(figsize=(11, 5.5))
    drew_any = False
    for offset, (variant, color) in zip(offsets, variants, strict=True):
        scores: list[float | None] = []
        for role, _metric, key in groups:
            hit = find_row(role, variant)
            scores.append(value(hit, key) if hit else None)
        bar_xs = [pos + offset for pos, score in zip(positions, scores, strict=True) if score is not None]
        bar_heights = [score for score in scores if score is not None]
        if bar_heights:
            drew_any = True
            plt.bar(bar_xs, bar_heights, width=width, label=variant.upper(), color=color)

    if drew_any:
        plt.xticks(positions, [f"{role}\n{metric}" for role, metric, _key in groups])
        plt.ylim(0, 1)
        plt.legend()
    else:
        plt.text(0.5, 0.5, "No held-out comparisons available", ha="center", va="center", transform=plt.gca().transAxes)
    plt.title("Held-out Scores: Base vs SFT vs GRPO")
    plt.ylabel("normalized score")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def main() -> None:
    """Read the dashboard JSON named on the command line and write every plot PNG."""
    args = parse_args()
    data = load_json(args.input)
    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # (metric keys tried in order, title, y-axis label, output file name)
    line_specs = (
        (("loss",), "Loss Over Step", "loss", "loss.png"),
        (("reward", "rewards/reward_completions/mean"), "Reward Over Step", "reward", "reward.png"),
        (
            ("completions/clipped_ratio", "clipped_ratio"),
            "Clipped Ratio Over Step",
            "clipped ratio",
            "clipped_ratio.png",
        ),
        (
            ("completions/mean_length", "completion_length", "mean_completion_length"),
            "Completion Length Over Step",
            "tokens",
            "completion_length.png",
        ),
    )
    for y_keys, title, ylabel, filename in line_specs:
        line_plot(data, y_keys, title, ylabel, out_dir / filename)

    invalid_plot(data, out_dir / "invalids.png")
    heldout_plot(data, out_dir / "heldout_comparison.png")
    print(f"Wrote plots to {out_dir}")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()