Yaz / demo.py
TilelliLab's picture
UX fix: demo.py (download instructions / safetensors / pickle-free demo)
e7289a8 verified
Raw
History Blame Contribute Delete
8.49 kB
"""Yaz+Engram demo — an editable knowledge model that KNOWS WHEN IT DOESN'T KNOW.
Free-text prompt -> route it to a fact-atom via a FROZEN Engram (MiniLM) embedding ->
if the routing CONFIDENCE MARGIN (top1-top2 centroid cosine) is high enough, ANSWER
(emit the routed fact); else ABSTAIN ("I'm not confident which fact you mean") and show
the top-2 candidates. Facts are live-editable (--edit) and deletable (--delete) with no
retraining, and the confidence signal is unchanged by edits.
This packages the routing-confidence abstention result (scripts/scaling/s3_route_abstain.py:
AURC 0.004 vs oracle 0.003; on hard name-free clues 0.194 vs GRACE-distance 0.436). Honest
scope: 807K byte-LM, country->capital, first-byte routing, CPU. A real, novel feature
(no published editor refuses on low routing confidence) — but not a unique advantage (the
pieces are copyable). Self-contained: imports the router/model but edits no shared file.
Usage:
python demo.py --demo # scripted transcript
python demo.py --prompt "the country of the Eiffel Tower"
python demo.py --prompt "The capital of France is " --edit France=Lima
python demo.py --prompt "What is the deal with quantum gravity" # -> ABSTAIN
"""
from __future__ import annotations
import os
import argparse, sys
from pathlib import Path
import numpy as np
import torch
# Directory that contains this script; every data/checkpoint path below is
# resolved relative to it.
ROOT = Path(__file__).resolve().parent
# Optional extra import location, read from the YAZ_EMBEDDER_PATH environment
# variable (empty string when the variable is unset).
_emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")
if _emb_path: # only add a real path (avoid inserting "" == cwd)
    sys.path.insert(0, _emb_path)
# Inserted last at position 0, so ROOT is searched before the embedder path
# when the project-local imports that follow are resolved.
sys.path.insert(0, str(ROOT))
from yaz import YazConfig, YazLM
from yaz.semantic_router import SemanticRouter
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
# Model artifacts. YazDemo loads SAFETENSORS + META when both files exist and
# otherwise falls back to CKPT.
CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt" # PyTorch pickle (fallback)
SAFETENSORS = ROOT / "model.safetensors" # recommended, pickle-free
# JSON sidecar for the safetensors weights; read for its "cfg" and
# "country_to_target_atom" entries.
META = ROOT / "yaz_meta.json"
# Default for the --threshold flag and for YazDemo(threshold=...).
DEFAULT_THRESHOLD = 0.08 # margin below this -> abstain (tunable; see s3 risk-coverage)
def ids(s):
    """Encode *s* as UTF-8 and return its byte values as a (1, len) long tensor."""
    byte_values = list(s.encode("utf-8"))
    row = torch.tensor(byte_values, dtype=torch.long)
    # Prepend a batch axis of size one.
    return row[None, :]
class YazDemo:
    """Answer-or-abstain wrapper around a routed ``YazLM`` fact model.

    A prompt is embedded by the semantic router and compared against one
    centroid per known country. When the gap between the best and second-best
    similarity is at least ``threshold`` the model generates an answer using
    the routed fact atom; otherwise the demo abstains and lists the two
    closest candidates. Facts can be overwritten (``edit``) or zeroed
    (``delete``) in place, without retraining.

    Attributes set by ``__init__``:
        cfg: ``YazConfig`` rebuilt from the stored config mapping.
        model: ``YazLM`` in eval mode with the stored weights loaded.
        c2i: mapping country name -> fact-atom index.
        order: country names in ``c2i`` key order (one centroid row each).
        threshold: minimum routing margin required to answer.
        capof: country -> capital, first occurrence in the facts file wins.
        router: ``SemanticRouter`` built over ``order``.
        C: the router's centroid matrix.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        """Load weights, the country->atom map, capitals and routing centroids.

        Prefers the pickle-free safetensors + JSON sidecar (the recommended
        artifact) and falls back to the PyTorch pickle checkpoint when either
        file is missing.
        """
        import json

        if SAFETENSORS.exists() and META.exists():
            from safetensors.torch import load_file

            meta = json.loads(META.read_text(encoding="utf-8"))
            cfg_kwargs = meta["cfg"]
            state = load_file(str(SAFETENSORS))
            self.c2i = meta["country_to_target_atom"]
        else:
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python
            # objects, which can execute code. Only use this fallback with a
            # checkpoint from a trusted source; prefer the safetensors path.
            ck = torch.load(CKPT, map_location="cpu", weights_only=False)
            cfg_kwargs = ck["cfg"]
            state = ck["model"]
            self.c2i = ck["country_to_target_atom"]

        self.cfg = YazConfig(**cfg_kwargs)
        self.model = YazLM(self.cfg)
        self.model.load_state_dict(state)
        self.model.eval()

        self.order = list(self.c2i.keys())
        self.threshold = threshold

        # capitals (full word) for resolving --edit RHS and for nicer output
        self.capof = {}
        facts_path = ROOT / "data" / "facts_para_train.jsonl"
        # Explicit UTF-8: capital names may be non-ASCII and the platform
        # default encoding is not guaranteed to be UTF-8.
        for line in facts_path.read_text(encoding="utf-8").splitlines():
            if line.strip():
                rec = json.loads(line)
                self.capof.setdefault(rec["country"], rec["capital"])

        self.router = SemanticRouter(self.order, TRAIN_TEMPLATES)
        self.router.build_centroids()
        self.C = self.router.centroids  # (n, dim) unit-norm

    # ---- routing with confidence ----
    def route(self, prompt):
        """Route *prompt* to its nearest centroid and measure confidence.

        Returns a tuple ``(index, margin, top3)``:
            index: position in ``self.order`` of the best-matching country.
            margin: best similarity minus second-best similarity.
            top3: up to three ``(country, similarity)`` pairs, best first.

        Needs at least two centroids, because the margin reads the runner-up.
        """
        v = self.router.embed(prompt)
        sims = self.C @ v
        ranked = np.argsort(-sims)  # descending similarity
        best, runner_up = ranked[0], ranked[1]
        margin = float(sims[best] - sims[runner_up])
        top3 = [(self.order[int(i)], float(sims[i])) for i in ranked[:3]]
        return int(best), margin, top3

    @torch.no_grad()
    def _gen(self, prompt, atom, n=10):
        """Greedily generate *n* tokens after *prompt* with the model routed to *atom*.

        Returns the generated tokens decoded byte-for-byte as latin-1 and
        stripped of surrounding whitespace.
        """
        route_atom = torch.tensor([int(atom)])
        out = ids(prompt)
        prompt_len = out.shape[1]
        max_ctx = self.cfg.max_seq_len
        for _ in range(n):
            # Feed only the most recent max_ctx tokens once the sequence
            # outgrows the model's context window.
            ctx = out if out.shape[1] <= max_ctx else out[:, -max_ctx:]
            logits = self.model(ctx, route_atom=route_atom)
            next_tok = logits[:, -1].argmax(-1, keepdim=True)
            out = torch.cat([out, next_tok], dim=1)
        return bytes(out[0, prompt_len:].tolist()).decode("latin-1", "ignore").strip()

    # ---- CRUD (live, no retrain) ----
    def _resolve_source(self, rhs):
        """RHS of --edit may be a country name or a capital; return source country.

        Country names are matched exactly, capitals case-insensitively.
        Returns ``None`` when nothing matches.
        """
        if rhs in self.c2i:
            return rhs
        wanted = rhs.lower()
        for country, capital in self.capof.items():
            if capital.lower() == wanted:
                return country
        return None

    def edit(self, tgt, rhs):
        """Make country *tgt* answer like the country identified by *rhs*.

        Copies the source country's ``W_dec`` column over the target's.
        Prints a notice and does nothing when either side is unknown.
        """
        src = self._resolve_source(rhs)
        if tgt not in self.c2i or src is None:
            print(f" [edit skipped: unknown '{tgt}' or '{rhs}']")
            return
        w_dec = self.model.fact_layer.W_dec.weight
        with torch.no_grad():
            w_dec[:, self.c2i[tgt]] = w_dec[:, self.c2i[src]].clone()
        print(f" ✎ EDIT applied: {tgt}'s capital -> {self.capof.get(src, src)} (copied {src}'s atom)")

    def delete(self, tgt):
        """Zero the fact atom of country *tgt*.

        Clears the atom's ``W_dec`` column, its ``W_enc`` row and its
        ``W_enc`` bias entry. Prints a notice and does nothing when *tgt* is
        unknown.
        """
        if tgt not in self.c2i:
            print(f" [delete skipped: unknown '{tgt}']")
            return
        a = self.c2i[tgt]
        layer = self.model.fact_layer
        with torch.no_grad():
            layer.W_dec.weight[:, a] = 0.0
            layer.W_enc.weight[a, :] = 0.0
            layer.W_enc.bias[a] = 0.0
        print(f" 🗑 DELETE applied: {tgt}'s atom zeroed")

    # ---- the answer-or-abstain decision ----
    def ask(self, prompt):
        """Print either an answer for *prompt* or an abstention with the top-2 candidates."""
        idx, margin, top3 = self.route(prompt)
        country = self.order[idx]
        print(f' Q: "{prompt}"')
        if margin < self.threshold:
            cands = ", ".join(name for name, _ in top3[:2])
            print(f" → ABSTAIN (margin {margin:.3f} < {self.threshold}). "
                  f"I'm not confident which fact you mean — {cands}?")
            return
        # Generate with the atom index that edit()/delete() address for this
        # country, so live edits are reflected even if atom indices differ
        # from the routing order.
        gen = self._gen(prompt, self.c2i[country])
        fb = gen[:1]  # _gen already strips surrounding whitespace
        print(f" → ANSWER (routed: {country}, margin {margin:.3f}): first byte '{fb}' "
              f"[raw gen: {gen!r}]")
def run_demo(d):
    """Print the scripted transcript: confident answer, name-free clue, live edit, abstention.

    *d* is the demo object; only its ``ask(prompt)`` and ``edit(tgt, rhs)``
    methods are used. The name-free clue is asked once before and once after
    the edit so the transcript shows the edit carrying over to a paraphrase.
    """
    clue = "the country of the Eiffel Tower and the Louvre, its capital is "

    print("\n=== Yaz+Engram: an editor that knows when it doesn't know ===\n")
    print("[1] Confident, in-scope question:")
    d.ask("The capital of France is ")
    print("\n[2] Name-free clue (semantic routing does the entity-ID):")
    d.ask(clue)
    print("\n[3] Live edit (no retrain), then re-ask a PARAPHRASE — edit transfers:")
    d.edit("France", "Lima")
    d.ask(clue)
    print("\n[4] Out-of-scope / unknown question — the model REFUSES instead of confabulating:")
    d.ask("What is the deal with quantum gravity and black holes")
    d.ask("Tell me about the best pizza topping")
    print("\n(Confidence = Engram top1-top2 centroid margin. The feature on show is ABSTENTION + edit;")
    print(" the answer is judged on the FIRST BYTE — this 807K model is a first-byte editor, so the")
    print(" multi-byte gen is garbled by design (full-word transfer ~0.05, measured). Evidence the")
    print(" confidence signal is real: s3 AURC 0.004 vs oracle 0.003; hard name-free 0.194 vs GRACE 0.436.")
    print(" A real, novel feature — no published editor refuses on low routing confidence — but a STEP,")
    print(" the pieces are individually published.)\n")
def main():
    """Command-line entry point.

    ``--demo`` runs the scripted transcript and ignores the other actions.
    Otherwise ``--delete``, ``--edit`` and ``--prompt`` are applied in that
    order. With no action at all, the help text is printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--prompt", type=str, default=None)
    ap.add_argument("--edit", type=str, default=None, help="Country=CapitalOrCountry, e.g. France=Lima")
    ap.add_argument("--delete", type=str, default=None, help="Country to delete, e.g. France")
    ap.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
    ap.add_argument("--demo", action="store_true")
    args = ap.parse_args()

    # Validate --edit before loading the model so a malformed spec produces a
    # usage error instead of a ValueError traceback after start-up.
    edit_pair = None
    if args.edit and not args.demo:
        tgt, sep, rhs = args.edit.partition("=")
        if not sep:
            ap.error("--edit expects Country=CapitalOrCountry, e.g. France=Lima")
        edit_pair = (tgt.strip(), rhs.strip())

    # Nothing to do: show help without requiring the model files to exist.
    if not (args.demo or args.prompt or args.edit or args.delete):
        ap.print_help()
        return

    torch.set_num_threads(1)
    d = YazDemo(threshold=args.threshold)
    if args.demo:
        run_demo(d)
        return
    if args.delete:
        d.delete(args.delete)
    if edit_pair is not None:
        d.edit(*edit_pair)
    if args.prompt:
        d.ask(args.prompt)
if __name__ == "__main__":
main()