# gary-neuron / training / benchmark.py
# gary-neuron: async NCA + top-2 MoE, 26k params, 99.97%/100% exact-match on 7-digit addition
# (uploaded by gary23w; commit 57f9808, verified; 3.76 kB)
"""Rigorous, adversarial benchmark for gary-neuron. Reports TRUE exact-match on a
large held-out test set the model never trained on, stress-tests the hardest
long-carry ripples, checks robustness to the random async update order, and
shows the NCA 'train short / run longer' property. Usage:
python benchmark.py main # held-out + adversarial + by-length
python benchmark.py sweep # inference-steps and update-prob sweeps
"""
import os, sys, json, numpy as np
from data import make_batch, exact_match, digits_rev, to_int
from garyneuron import forward_np
# ---- run-mode and checkpoint setup (module level: executes when the script starts) ----

# Directory containing this script; the default checkpoint path is built from it.
D = os.path.dirname(os.path.abspath(__file__))

# Benchmark selector: first command-line argument, defaulting to "main".
mode = sys.argv[1] if len(sys.argv) > 1 else "main"
if mode not in ("main", "sweep"):
    # Without this check an unrecognised mode would load the checkpoint, print the
    # banner, run no benchmark at all and still exit with status 0. Fail fast
    # (before touching the checkpoint) with a usage message instead.
    sys.exit(f"unknown mode {mode!r}; usage: python benchmark.py [main|sweep]")

# Checkpoint path; override with the CKPT environment variable.
CKPT = os.environ.get("CKPT", f"{D}/final.npz")

# NOTE(security): allow_pickle=True permits np.load to unpickle object arrays, which
# can execute arbitrary code. Only point CKPT at checkpoints you trust.
z = np.load(CKPT, allow_pickle=True)

# Model weights: every archive entry whose key starts with "P/", with that prefix removed.
W = {k[2:]: z[k] for k in z.files if k.startswith("P/")}

# Training configuration, stored in the archive as a JSON string.
cfg = json.loads(str(z["cfg"]))

# Size argument handed to make_batch / digits_rev
# (presumably the digit-grid width — confirm against data.py).
S = cfg["S"]

print(f"# gary-neuron benchmark | step {int(z['step'])} | trained steps={cfg['steps']} p={cfg['p_update']} | {mode}")
def C(**kw):
    """Return a copy of the checkpoint config with the given keys overridden.

    The module-level ``cfg`` is left untouched; every call builds a fresh dict,
    e.g. ``C(steps=28)`` or ``C(steps=28, p_update=0.5)``.
    """
    return {**cfg, **kw}
def grids(pairs):
    """Encode ``(a, b)`` operand pairs as three arrays: A, B and the target sum Y.

    Every integer is passed through ``digits_rev(value, S)``; row ``i`` of each
    returned array corresponds to ``pairs[i]``. ``pairs`` is traversed three times,
    so pass a re-iterable sequence (such as a list), not a one-shot iterator.
    """
    first_rows = [digits_rev(x, S) for x, _ in pairs]
    second_rows = [digits_rev(y, S) for _, y in pairs]
    sum_rows = [digits_rev(x + y, S) for x, y in pairs]
    return np.array(first_rows), np.array(second_rows), np.array(sum_rows)
if mode == "main":
    # ---- 1. large held-out test, repeated under several async-update RNG seeds ----
    A, B, Y = make_batch(10000, S, np.random.default_rng(20260611), 7)
    ems = []
    for seed in range(8):
        # The inputs stay fixed; only the generator handed to forward_np changes.
        out = forward_np(W, A, B, cfg, np.random.default_rng(seed))
        ems.append(exact_match(out, Y))
    print("\n[held-out 10k, ≤7-digit] exact-match across 8 random async orders:")
    print(
        f" mean {np.mean(ems)*100:.3f}%"
        f" min {np.min(ems)*100:.3f}%"
        f" max {np.max(ems)*100:.3f}%"
        f" std {np.std(ems)*100:.4f}"
    )

    # ---- 2. adversarial maximal-carry ripples (run a bit longer: steps=28) ----
    hard = []
    # NOTE(review): n_digits=8 yields 8-digit operands (and 9-digit sums), beyond the
    # ≤7-digit range sampled in section 1 — confirm these fit what digits_rev(…, S) encodes.
    for n_digits in range(1, 9):
        nines = 10**n_digits - 1
        hard.extend([
            (nines, 1),      # 99..9 + 1 -> full-length carry
            (nines, nines),  # 99..9 + 99..9
        ])
    # NOTE(review): the loop above already adds (9999999, 1) and (9999999, 9999999)
    # at n_digits=7, so those two cases appear twice and are counted twice below.
    hard.extend([
        (9999999, 1), (9999999, 9999999), (5555555, 4444445),
        (1234567, 8765433), (9090909, 909091), (7777777, 2222223),
    ])
    HA, HB, HY = grids(hard)
    pred = forward_np(W, HA, HB, C(steps=28), np.random.default_rng(0))
    # A case counts as correct only if every position of its row matches the target.
    ok = (pred == HY).all(axis=1)
    print(f"\n[adversarial max-carry, {len(hard)} cases @ steps=28] {int(ok.sum())}/{len(hard)} correct")
    for (a, b), o, p in zip(hard, ok, pred):
        if o:
            continue  # only failures are listed
        # to_int is called on a batch of one row (p[None]); take its single result.
        got = to_int(p[None])[0]
        print(f" MISS {a} + {b} = {a+b} got {got}")

    # ---- 3. accuracy by operand length ----
    print("\n[exact-match by max operand length @ steps=24]")
    for n_digits in range(1, 8):
        rng = np.random.default_rng(500 + n_digits)
        # Both operands are drawn from [lo, hi): exactly n_digits digits each,
        # except that the 1-digit bucket also includes 0.
        lo = 0 if n_digits == 1 else 10**(n_digits - 1)
        hi = 10**n_digits
        a = rng.integers(lo, hi, 4000)
        b = rng.integers(lo, hi, 4000)
        A, B, Y = grids(list(zip(a.tolist(), b.tolist())))
        em = exact_match(forward_np(W, A, B, C(steps=24), np.random.default_rng(0)), Y)
        print(f" len {n_digits}: {em*100:6.3f}%")
if mode == "sweep":
    # One fixed evaluation batch is shared by both sweeps so their rows are comparable.
    A, B, Y = make_batch(10000, S, np.random.default_rng(424242), 7)

    # ---- inference-steps sweep: vary only the number of steps passed to forward_np ----
    # The trained step count is read from the checkpoint config instead of being
    # hard-coded, so the label stays correct when CKPT points at another checkpoint.
    print(f"\n[inference async-steps sweep] (trained at {cfg['steps']})")
    for st in [8, 10, 12, 16, 20, 24, 28, 32]:
        # Three generator seeds per setting; report mean and standard deviation.
        ems = [
            exact_match(forward_np(W, A, B, C(steps=st), np.random.default_rng(s)), Y)
            for s in range(3)
        ]
        print(f" steps={st:2d}: exact {np.mean(ems)*100:6.3f}% (±{np.std(ems)*100:.3f})")

    # ---- update-probability sweep: steps fixed at 28, vary p_update ----
    print("\n[update-probability sweep @ steps=28] p=1.0 is fully synchronous")
    for p in [0.25, 0.5, 0.75, 1.0]:
        ems = [
            exact_match(forward_np(W, A, B, C(steps=28, p_update=p), np.random.default_rng(s)), Y)
            for s in range(3)
        ]
        print(f" p_update={p}: exact {np.mean(ems)*100:6.3f}% (±{np.std(ems)*100:.3f})")