| """Yaz+Engram demo — an editable knowledge model that KNOWS WHEN IT DOESN'T KNOW. |
| |
| Free-text prompt -> route it to a fact-atom via a FROZEN Engram (MiniLM) embedding -> |
| if the routing CONFIDENCE MARGIN (top1-top2 centroid cosine) is high enough, ANSWER |
| (emit the routed fact); else ABSTAIN ("I'm not confident which fact you mean") and show |
| the top-2 candidates. Facts are live-editable (--edit) and deletable (--delete) with no |
| retraining, and the confidence signal is unchanged by edits. |
| |
| This packages the routing-confidence abstention result (scripts/scaling/s3_route_abstain.py: |
| AURC 0.004 vs oracle 0.003; on hard name-free clues 0.194 vs GRACE-distance 0.436). Honest |
| scope: 807K byte-LM, country->capital, first-byte routing, CPU. A real, novel feature |
| (no published editor refuses on low routing confidence) — but not a unique advantage (the |
| pieces are copyable). Self-contained: imports the router/model but edits no shared file. |
| |
| Usage: |
| python demo.py --demo # scripted transcript |
| python demo.py --prompt "the country of the Eiffel Tower" |
| python demo.py --prompt "The capital of France is " --edit France=Lima |
| python demo.py --prompt "What is the deal with quantum gravity" # -> ABSTAIN |
| """ |
| from __future__ import annotations |
| import os |
| import argparse, sys |
| from pathlib import Path |
| import numpy as np |
| import torch |
|
|
# Directory that contains this script; all data and weight paths hang off it.
ROOT = Path(__file__).resolve().parent

# Optional extra import location supplied through the environment.
_emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")
if _emb_path:
    sys.path.insert(0, _emb_path)
# Inserted last, so ROOT ends up ahead of the optional embedder path on sys.path.
sys.path.insert(0, str(ROOT))

# NOTE(review): these imports follow the sys.path edits above, presumably so the
# packages resolve from ROOT (or the embedder path) -- confirm before reordering.
from yaz import YazConfig, YazLM
from yaz.semantic_router import SemanticRouter
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES

# Pickled checkpoint, read by YazDemo when the safetensors/meta pair is absent.
CKPT = ROOT.joinpath("checkpoints", "yaz_gen_semantic_v2.pt")
# Preferred weight file and its JSON side-car (model config + country->atom map).
SAFETENSORS = ROOT.joinpath("model.safetensors")
META = ROOT.joinpath("yaz_meta.json")
# Default for --threshold / YazDemo(threshold=...): minimum routing margin to answer.
DEFAULT_THRESHOLD = 0.08
|
|
|
|
def ids(s):
    """Return the UTF-8 bytes of ``s`` as a long tensor with a leading axis of size 1."""
    byte_values = list(s.encode("utf-8"))
    row = torch.tensor(byte_values, dtype=torch.long)
    # Add the leading axis so the result has shape (1, number_of_bytes).
    return row.unsqueeze(0)
|
|
|
|
class YazDemo:
    """Route a prompt to a fact atom, then answer or abstain on the routing margin.

    A prompt is embedded by a ``SemanticRouter`` and scored against one
    centroid per country. When the gap between the best and second-best score
    is at least ``threshold`` the model generates conditioned on the routed
    atom; otherwise the demo abstains and names the two closest candidates.
    Facts can be overwritten (``edit``) or zeroed (``delete``) in place.
    """

    # NOTE(review): the marker glyphs in the status messages below (before
    # "EDIT applied", "DELETE applied", "ABSTAIN", "ANSWER", and the dash in the
    # abstain sentence) look mis-encoded. They are reproduced exactly as found;
    # confirm the intended symbols before changing them.

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        """Load weights, the country->atom map, the capitals table and centroids.

        Weights come from ``SAFETENSORS`` + ``META`` when both files exist,
        otherwise from the pickled checkpoint at ``CKPT``.

        Args:
            threshold: minimum top1-top2 routing margin required to answer.

        Sets ``cfg``, ``model``, ``c2i`` (country -> atom index), ``order``
        (countries in ``c2i`` key order), ``threshold``, ``capof``
        (country -> capital), ``router`` and ``C`` (the router's centroids).
        """
        import json

        if SAFETENSORS.exists() and META.exists():
            from safetensors.torch import load_file

            # JSON is UTF-8 by specification; do not rely on the locale default.
            meta = json.loads(META.read_text(encoding="utf-8"))
            cfg_kwargs = meta["cfg"]
            country_to_atom = meta["country_to_target_atom"]
            state_dict = load_file(str(SAFETENSORS))
        else:
            # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
            # objects; only point CKPT at checkpoints from a trusted source.
            ck = torch.load(CKPT, map_location="cpu", weights_only=False)
            cfg_kwargs = ck["cfg"]
            country_to_atom = ck["country_to_target_atom"]
            state_dict = ck["model"]

        self.cfg = YazConfig(**cfg_kwargs)
        self.model = YazLM(self.cfg)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.c2i = country_to_atom
        self.order = list(self.c2i.keys())
        self.threshold = threshold

        # country -> capital; the first record seen for a country wins.
        self.capof = {}
        facts_path = ROOT / "data" / "facts_para_train.jsonl"
        for line in facts_path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue  # tolerate blank / whitespace-only lines
            record = json.loads(line)
            self.capof.setdefault(record["country"], record["capital"])

        self.router = SemanticRouter(self.order, TRAIN_TEMPLATES)
        self.router.build_centroids()
        self.C = self.router.centroids

    def route(self, prompt):
        """Score ``prompt`` against every centroid and rank the countries.

        Returns:
            ``(atom, margin, top3)`` where ``atom`` is the index of the
            best-scoring centroid, ``margin`` is the best score minus the
            second-best score, and ``top3`` is a list of up to three
            ``(country, score)`` pairs in descending score order.
        """
        query = self.router.embed(prompt)
        # NOTE(review): the module docstring calls these cosine similarities;
        # that holds only if the router returns normalized vectors -- confirm.
        sims = self.C @ query
        ranked = np.argsort(-sims)
        best, runner_up = ranked[0], ranked[1]
        margin = float(sims[best] - sims[runner_up])
        top3 = [(self.order[int(i)], float(sims[i])) for i in ranked[:3]]
        return int(best), margin, top3

    @torch.no_grad()
    def _gen(self, prompt, atom, n=10):
        """Greedily generate ``n`` tokens after ``prompt`` with routing fixed to ``atom``.

        Each step feeds at most the last ``cfg.max_seq_len`` tokens to the
        model and appends the argmax at the final position. Only the newly
        generated tokens are returned, decoded as latin-1 and stripped.
        """
        route_atom = torch.tensor([int(atom)])
        tokens = ids(prompt)
        prompt_len = tokens.shape[1]
        max_ctx = self.cfg.max_seq_len
        for _ in range(n):
            ctx = tokens if tokens.shape[1] <= max_ctx else tokens[:, -max_ctx:]
            # presumably the model returns per-position scores; take the last position
            next_token = self.model(ctx, route_atom=route_atom)[:, -1].argmax(-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
        generated = bytes(tokens[0, prompt_len:].tolist())
        return generated.decode("latin-1", "ignore").strip()

    def _resolve_source(self, rhs):
        """RHS of --edit may be a country name or a capital; return source country.

        A country name must match a key of ``c2i`` exactly; a capital is
        compared case-insensitively. Returns ``None`` when nothing matches.
        """
        if rhs in self.c2i:
            return rhs
        wanted = rhs.lower()
        for country, capital in self.capof.items():
            if capital.lower() == wanted:
                return country
        return None

    def edit(self, tgt, rhs):
        """Overwrite ``tgt``'s decoder column with a copy of the source country's.

        ``rhs`` is resolved through ``_resolve_source``. If ``tgt`` is unknown
        or ``rhs`` cannot be resolved, a notice is printed and nothing changes.
        """
        src = self._resolve_source(rhs)
        if tgt not in self.c2i or src is None:
            print(f" [edit skipped: unknown '{tgt}' or '{rhs}']")
            return
        decoder = self.model.fact_layer.W_dec.weight
        with torch.no_grad():
            decoder[:, self.c2i[tgt]] = decoder[:, self.c2i[src]].clone()
        print(f" β EDIT applied: {tgt}'s capital -> {self.capof.get(src, src)} (copied {src}'s atom)")

    def delete(self, tgt):
        """Zero ``tgt``'s decoder column, encoder row and encoder bias entry.

        The router and the country->atom map are left untouched. An unknown
        ``tgt`` prints a notice and changes nothing.
        """
        if tgt not in self.c2i:
            print(f" [delete skipped: unknown '{tgt}']")
            return
        atom = self.c2i[tgt]
        layer = self.model.fact_layer
        with torch.no_grad():
            layer.W_dec.weight[:, atom] = 0.0
            layer.W_enc.weight[atom, :] = 0.0
            layer.W_enc.bias[atom] = 0.0
        print(f" π DELETE applied: {tgt}'s atom zeroed")

    def ask(self, prompt):
        """Route ``prompt`` and print either an answer or an abstention.

        Abstains (listing the two best candidates) when the routing margin is
        below ``self.threshold``; otherwise generates with the routed atom and
        reports the first character of the generation plus the raw text.
        """
        atom, margin, top3 = self.route(prompt)
        country = self.order[atom]
        if margin < self.threshold:
            cands = ", ".join(f"{c}" for c, _ in top3[:2])
            print(f' Q: "{prompt}"')
            print(f" β ABSTAIN (margin {margin:.3f} < {self.threshold}). "
                  f"I'm not confident which fact you mean β {cands}?")
            return
        gen = self._gen(prompt, atom)
        fb = gen.lstrip()[:1]
        print(f' Q: "{prompt}"')
        print(f" β ANSWER (routed: {country}, margin {margin:.3f}): first byte '{fb}' "
              f"[raw gen: {gen!r}]")
|
|
|
|
| def run_demo(d): |
| print("\n=== Yaz+Engram: an editor that knows when it doesn't know ===\n") |
| print("[1] Confident, in-scope question:") |
| d.ask("The capital of France is ") |
| print("\n[2] Name-free clue (semantic routing does the entity-ID):") |
| d.ask("the country of the Eiffel Tower and the Louvre, its capital is ") |
| print("\n[3] Live edit (no retrain), then re-ask a PARAPHRASE β edit transfers:") |
| d.edit("France", "Lima") |
| d.ask("the country of the Eiffel Tower and the Louvre, its capital is ") |
| print("\n[4] Out-of-scope / unknown question β the model REFUSES instead of confabulating:") |
| d.ask("What is the deal with quantum gravity and black holes") |
| d.ask("Tell me about the best pizza topping") |
| print("\n(Confidence = Engram top1-top2 centroid margin. The feature on show is ABSTENTION + edit;") |
| print(" the answer is judged on the FIRST BYTE β this 807K model is a first-byte editor, so the") |
| print(" multi-byte gen is garbled by design (full-word transfer ~0.05, measured). Evidence the") |
| print(" confidence signal is real: s3 AURC 0.004 vs oracle 0.003; hard name-free 0.194 vs GRACE 0.436.") |
| print(" A real, novel feature β no published editor refuses on low routing confidence β but a STEP,") |
| print(" the pieces are individually published.)\n") |
|
|
|
|
def main():
    """Command-line entry point: parse options, then run the demo or one query.

    ``--demo`` runs the scripted transcript and ignores the other actions.
    Otherwise ``--delete``, ``--edit`` and ``--prompt`` are applied in that
    order. With no action at all, the help text is printed and the model is
    not loaded. A malformed ``--edit`` value exits through ``ap.error``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--prompt", type=str, default=None)
    ap.add_argument("--edit", type=str, default=None, help="Country=CapitalOrCountry, e.g. France=Lima")
    ap.add_argument("--delete", type=str, default=None, help="Country to delete, e.g. France")
    ap.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
    ap.add_argument("--demo", action="store_true")
    args = ap.parse_args()

    # Validate --edit before the (slow) model load so a typo fails fast with a
    # usage message instead of an unpacking ValueError. --demo ignores --edit.
    edit_pair = None
    if args.edit and not args.demo:
        tgt, sep, rhs = args.edit.partition("=")
        if not sep:
            ap.error(f"--edit expects Country=CapitalOrCountry, e.g. France=Lima (got {args.edit!r})")
        edit_pair = (tgt.strip(), rhs.strip())

    # Nothing to do: show usage without loading the model.
    if not (args.demo or args.prompt or args.edit or args.delete):
        ap.print_help()
        return

    torch.set_num_threads(1)
    d = YazDemo(threshold=args.threshold)
    if args.demo:
        run_demo(d)
        return
    if args.delete:
        d.delete(args.delete)
    if edit_pair is not None:
        d.edit(*edit_pair)
    if args.prompt:
        d.ask(args.prompt)


if __name__ == "__main__":
    main()
|
|