"""Rigorous, adversarial benchmark for gary-neuron.

Reports TRUE exact-match on a large held-out test set the model never trained
on, stress-tests the hardest long-carry ripples, checks robustness to the
random async update order, and shows the NCA 'train short / run longer'
property.

Usage:
    python benchmark.py main    # held-out + adversarial + by-length
    python benchmark.py sweep   # inference-steps and update-prob sweeps

With no argument the mode defaults to ``main``. The checkpoint path defaults
to ``final.npz`` next to this file and can be overridden with the ``CKPT``
environment variable.
"""
import json
import os
import sys

import numpy as np

from data import make_batch, exact_match, digits_rev, to_int
from garyneuron import forward_np

MODES = ("main", "sweep")
mode = sys.argv[1] if len(sys.argv) > 1 else "main"
if mode not in MODES:
    # A mistyped mode used to print the header and then silently run nothing,
    # which looks like a successful (empty) benchmark. Fail loudly instead.
    sys.exit(f"unknown mode {mode!r}; usage: python benchmark.py [{'|'.join(MODES)}]")

D = os.path.dirname(os.path.abspath(__file__))
CKPT = os.environ.get("CKPT", f"{D}/final.npz")

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code. Only point CKPT at checkpoints you trust. It is
# kept because the on-disk encoding of "cfg"/"step" is not visible from here.
# The archive is read inside ``with`` so the underlying file handle is closed;
# every array needed later is materialised before the block exits.
with np.load(CKPT, allow_pickle=True) as z:
    # Parameter arrays are stored under "P/<name>"; strip the two-char prefix.
    W = {k[2:]: z[k] for k in z.files if k.startswith("P/")}
    cfg = json.loads(str(z["cfg"]))
    STEP = int(z["step"])
S = cfg["S"]

print(f"# gary-neuron benchmark | step {STEP} | trained steps={cfg['steps']} p={cfg['p_update']} | {mode}")


def C(**kw):
    """Return a copy of the checkpoint config with ``kw`` overriding entries.

    The module-level ``cfg`` is never mutated, so per-experiment overrides
    (e.g. ``steps=28``) cannot leak between runs.
    """
    return {**cfg, **kw}


def grids(pairs):
    """Encode ``(a, b)`` integer pairs as model inputs and targets.

    Args:
        pairs: sequence of ``(a, b)`` integer tuples.

    Returns:
        ``(A, B, Y)`` arrays built by applying ``digits_rev(., S)`` to ``a``,
        ``b`` and ``a + b`` respectively, one row per pair.
        (``digits_rev`` lives in ``data``; presumably a fixed-width,
        least-significant-first digit encoding -- confirm there.)
    """
    A = np.array([digits_rev(a, S) for a, _ in pairs])
    B = np.array([digits_rev(b, S) for _, b in pairs])
    Y = np.array([digits_rev(a + b, S) for a, b in pairs])
    return A, B, Y


if mode == "main":
    # ---- 1. large held-out test, robustness to async update order ----
    # Same data, eight different RNG seeds for the forward pass: the spread
    # across seeds measures sensitivity to the random update order.
    A, B, Y = make_batch(10000, S, np.random.default_rng(20260611), 7)
    ems = [
        exact_match(forward_np(W, A, B, cfg, np.random.default_rng(s)), Y)
        for s in range(8)
    ]
    print("\n[held-out 10k, ≤7-digit] exact-match across 8 random async orders:")
    print(f" mean {np.mean(ems)*100:.3f}% min {np.min(ems)*100:.3f}% max {np.max(ems)*100:.3f}% std {np.std(ems)*100:.4f}")

    # ---- 2. adversarial maximal-carry ripples (run a bit longer: steps=28) ----
    hard = []
    for L in range(1, 9):
        hard.append((10**L - 1, 1))            # 99..9 + 1 -> full-length carry
        hard.append((10**L - 1, 10**L - 1))    # 99..9 + 99..9
    # NOTE(review): L=8 yields 8-digit operands whose sum needs 9 digits, which
    # is beyond the <=7-digit held-out range above -- confirm S is wide enough.
    # NOTE(review): (9999999, 1) and (9999999, 9999999) below repeat the L=7
    # cases from the loop; they are kept so the reported case count is
    # unchanged.
    hard += [
        (9999999, 1),
        (9999999, 9999999),
        (5555555, 4444445),
        (1234567, 8765433),
        (9090909, 909091),
        (7777777, 2222223),
    ]
    HA, HB, HY = grids(hard)
    pred = forward_np(W, HA, HB, C(steps=28), np.random.default_rng(0))
    ok = (pred == HY).all(1)  # a case counts only if every digit matches
    print(f"\n[adversarial max-carry, {len(hard)} cases @ steps=28] {int(ok.sum())}/{len(hard)} correct")
    # Only failures are listed, so the label is always "MISS".
    for (a, b), hit, row in zip(hard, ok, pred):
        if not hit:
            print(f" MISS {a} + {b} = {a+b} got {to_int(row[None])[0]}")

    # ---- 3. accuracy by operand length ----
    print("\n[exact-match by max operand length @ steps=24]")
    for L in range(1, 8):
        rng = np.random.default_rng(500 + L)
        # Exactly-L-digit operands, except L=1 which also admits 0.
        lo = 10**(L - 1) if L > 1 else 0
        # Draw a before b: the order fixes which samples each operand gets.
        a = rng.integers(lo, 10**L, 4000)
        b = rng.integers(lo, 10**L, 4000)
        A, B, Y = grids(list(zip(a.tolist(), b.tolist())))
        em = exact_match(forward_np(W, A, B, C(steps=24), np.random.default_rng(0)), Y)
        print(f" len {L}: {em*100:6.3f}%")

elif mode == "sweep":
    A, B, Y = make_batch(10000, S, np.random.default_rng(424242), 7)

    # Vary the number of inference steps, three seeds each.
    print("\n[inference async-steps sweep] (trained at 20)")
    for st in [8, 10, 12, 16, 20, 24, 28, 32]:
        ems = [
            exact_match(forward_np(W, A, B, C(steps=st), np.random.default_rng(s)), Y)
            for s in range(3)
        ]
        print(f" steps={st:2d}: exact {np.mean(ems)*100:6.3f}% (±{np.std(ems)*100:.3f})")

    # Vary the per-step update probability at a fixed step budget.
    print("\n[update-probability sweep @ steps=28] p=1.0 is fully synchronous")
    for p in [0.25, 0.5, 0.75, 1.0]:
        ems = [
            exact_match(forward_np(W, A, B, C(steps=28, p_update=p), np.random.default_rng(s)), Y)
            for s in range(3)
        ]
        print(f" p_update={p}: exact {np.mean(ems)*100:6.3f}% (±{np.std(ems)*100:.3f})")