File size: 8,488 Bytes
"""Yaz+Engram demo — an editable knowledge model that KNOWS WHEN IT DOESN'T KNOW.
Free-text prompt -> route it to a fact-atom via a FROZEN Engram (MiniLM) embedding ->
if the routing CONFIDENCE MARGIN (top1-top2 centroid cosine) is high enough, ANSWER
(emit the routed fact); else ABSTAIN ("I'm not confident which fact you mean") and show
the top-2 candidates. Facts are live-editable (--edit) and deletable (--delete) with no
retraining, and the confidence signal is unchanged by edits.
This packages the routing-confidence abstention result (scripts/scaling/s3_route_abstain.py:
AURC 0.004 vs oracle 0.003; on hard name-free clues 0.194 vs GRACE-distance 0.436). Honest
scope: 807K byte-LM, country->capital, first-byte routing, CPU. A real, novel feature
(no published editor refuses on low routing confidence) — but not a unique advantage (the
pieces are copyable). Self-contained: imports the router/model but edits no shared file.
Usage:
python demo.py --demo # scripted transcript
python demo.py --prompt "the country of the Eiffel Tower"
python demo.py --prompt "The capital of France is " --edit France=Lima
python demo.py --prompt "What is the deal with quantum gravity" # -> ABSTAIN
"""
from __future__ import annotations
import os
import argparse, sys
from pathlib import Path
import numpy as np
import torch
# Repository root: the directory that contains this script.
ROOT = Path(__file__).resolve().parent

# Optional extra import location for the embedder package. Only a non-empty
# value is used, because "" on sys.path means "the current working directory".
_emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")

# Prepend in one step; the resulting order is [ROOT, <embedder path>, ...old path].
sys.path[0:0] = [str(ROOT)] + ([_emb_path] if _emb_path else [])
from yaz import YazConfig, YazLM
from yaz.semantic_router import SemanticRouter
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES
# Model artifacts, all resolved relative to the repository root.
CKPT = ROOT.joinpath("checkpoints", "yaz_gen_semantic_v2.pt")  # PyTorch pickle (fallback)
SAFETENSORS = ROOT.joinpath("model.safetensors")  # recommended, pickle-free
META = ROOT.joinpath("yaz_meta.json")  # JSON sidecar: model "cfg" + "country_to_target_atom"
# Routing margin below which the demo abstains (tunable; see s3 risk-coverage).
DEFAULT_THRESHOLD = 0.08
def ids(s):
    """Return the UTF-8 bytes of ``s`` as a ``(1, n_bytes)`` int64 tensor (a batch of one)."""
    byte_values = list(s.encode("utf-8"))
    return torch.tensor([byte_values], dtype=torch.long)
class YazDemo:
    """Answer-or-abstain front end for a YazLM fact editor.

    A prompt is embedded by the SemanticRouter and compared against one
    centroid per country. If the gap between the best and second-best centroid
    similarity (the routing margin) is below ``threshold``, the demo abstains.
    Otherwise it generates with the routed country's fact atom. Facts are
    edited or deleted in place by rewriting ``model.fact_layer`` weights, with
    no retraining. The router, and therefore the margin, is left untouched.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        """Load the model, the country->atom map, capitals and the router.

        Args:
            threshold: routing margin below which ``ask`` abstains.
        """
        self._load_model()
        self.order = list(self.c2i.keys())  # router index -> country
        self.threshold = threshold
        # capitals (full word) for resolving --edit RHS and for nicer output
        self.capof = self._load_capitals(ROOT / "data" / "facts_para_train.jsonl")
        self.router = SemanticRouter(self.order, TRAIN_TEMPLATES)
        self.router.build_centroids()
        self.C = self.router.centroids  # (n, dim) unit-norm

    def _load_model(self):
        """Set ``self.cfg``, ``self.model`` (in eval mode) and ``self.c2i``.

        Prefers the pickle-free safetensors weights plus the JSON sidecar (the
        recommended artifact). Falls back to the PyTorch pickle checkpoint if
        either of those files is missing.
        """
        import json

        if SAFETENSORS.exists() and META.exists():
            from safetensors.torch import load_file

            meta = json.loads(META.read_text(encoding="utf-8"))
            cfg, c2i = meta["cfg"], meta["country_to_target_atom"]
            state = load_file(str(SAFETENSORS))
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects (which
            # can execute code); only load a checkpoint from a trusted source.
            ck = torch.load(CKPT, map_location="cpu", weights_only=False)
            cfg, c2i, state = ck["cfg"], ck["country_to_target_atom"], ck["model"]
        self.cfg = YazConfig(**cfg)
        self.model = YazLM(self.cfg)
        self.model.load_state_dict(state)
        self.model.eval()
        self.c2i = c2i

    @staticmethod
    def _load_capitals(path):
        """Map country -> capital from a JSONL file; the first row per country wins."""
        import json

        capof = {}
        for line in path.read_text(encoding="utf-8").splitlines():
            if not line.strip():  # tolerate blank / whitespace-only lines
                continue
            row = json.loads(line)
            capof.setdefault(row["country"], row["capital"])
        return capof

    # ---- routing with confidence ----
    def route(self, prompt):
        """Route ``prompt`` to a country and report how confident the routing is.

        Returns:
            ``(index, margin, top3)``:
            - ``index`` is a position in ``self.order``.
            - ``margin`` is the top-1 minus top-2 centroid similarity. It is
              ``+inf`` when there is only one centroid and hence no runner-up.
            - ``top3`` holds up to three ``(country, similarity)`` pairs, best first.
        """
        v = self.router.embed(prompt)
        sims = self.C @ v
        o = np.argsort(-sims)
        best = int(o[0])
        margin = float(sims[o[0]] - sims[o[1]]) if len(o) > 1 else float("inf")
        top3 = [(self.order[int(i)], float(sims[i])) for i in o[:3]]
        return best, margin, top3

    @torch.no_grad()
    def _gen(self, prompt, atom, n=10):
        """Greedily decode ``n`` bytes after ``prompt`` with ``route_atom=atom``.

        The context is cropped to the last ``cfg.max_seq_len`` bytes. Output
        bytes are decoded as latin-1 (exactly one char per byte, never fails)
        and then stripped of surrounding whitespace.
        """
        ra = torch.tensor([int(atom)])
        out = ids(prompt)
        plen = out.shape[1]
        mc = self.cfg.max_seq_len
        for _ in range(n):
            ctx = out if out.shape[1] <= mc else out[:, -mc:]
            next_byte = self.model(ctx, route_atom=ra)[:, -1].argmax(-1, keepdim=True)
            out = torch.cat([out, next_byte], dim=1)
        return bytes(out[0, plen:].tolist()).decode("latin-1", "ignore").strip()

    # ---- CRUD (live, no retrain) ----
    def _resolve_source(self, rhs):
        """RHS of --edit may be a country name or a capital; return source country.

        Country names match exactly. Capitals match case-insensitively.
        Returns None when nothing matches.
        """
        if rhs in self.c2i:
            return rhs
        for c, cap in self.capof.items():
            if cap.lower() == rhs.lower():
                return c
        return None

    def edit(self, tgt, rhs):
        """Make ``tgt`` answer like the country named or implied by ``rhs``.

        Copies the source country's ``fact_layer.W_dec`` column into ``tgt``'s
        column. Prints a skip message and changes nothing if either side is unknown.
        """
        src = self._resolve_source(rhs)
        if tgt not in self.c2i or src is None:
            print(f" [edit skipped: unknown '{tgt}' or '{rhs}']")
            return
        w_dec = self.model.fact_layer.W_dec.weight
        with torch.no_grad():
            w_dec[:, self.c2i[tgt]] = w_dec[:, self.c2i[src]].clone()
        print(f" ✏ EDIT applied: {tgt}'s capital -> {self.capof.get(src, src)} (copied {src}'s atom)")

    def delete(self, tgt):
        """Zero ``tgt``'s fact atom: its W_dec column and its W_enc row and bias entry."""
        if tgt not in self.c2i:
            print(f" [delete skipped: unknown '{tgt}']")
            return
        a = self.c2i[tgt]
        fl = self.model.fact_layer
        with torch.no_grad():
            fl.W_dec.weight[:, a] = 0.0
            fl.W_enc.weight[a, :] = 0.0
            fl.W_enc.bias[a] = 0.0
        print(f" 🗑 DELETE applied: {tgt}'s atom zeroed")

    # ---- the answer-or-abstain decision ----
    def ask(self, prompt):
        """Print an ANSWER (routed generation) or an ABSTAIN (top-2 candidates)."""
        idx, margin, top3 = self.route(prompt)
        country = self.order[idx]
        print(f' Q: "{prompt}"')
        if margin < self.threshold:
            cands = ", ".join(c for c, _ in top3[:2])
            print(f" ✗ ABSTAIN (margin {margin:.3f} < {self.threshold}). "
                  f"I'm not confident which fact you mean — {cands}?")
            return
        # Generate with the country's fact atom: the same c2i index that
        # edit()/delete() rewrite. idx is only a position in self.order, so it
        # coincides with the atom only when c2i is the identity map.
        gen = self._gen(prompt, self.c2i[country])
        fb = gen[:1]  # _gen already strips, so this is the first non-space byte
        print(f" ✓ ANSWER (routed: {country}, margin {margin:.3f}): first byte '{fb}' "
              f"[raw gen: {gen!r}]")
def run_demo(d):
    """Print the scripted transcript: answer, name-free routing, live edit, abstain.

    ``d`` must provide ``ask(prompt)`` and ``edit(tgt, rhs)`` (a YazDemo).
    Step [3] edits ``d``'s model in place, so later calls on ``d`` see the edit.
    """
    # Used before and after the edit so the transcript shows the edit transferring.
    clue = "the country of the Eiffel Tower and the Louvre, its capital is "
    print("\n=== Yaz+Engram: an editor that knows when it doesn't know ===\n")
    print("[1] Confident, in-scope question:")
    d.ask("The capital of France is ")
    print("\n[2] Name-free clue (semantic routing does the entity-ID):")
    d.ask(clue)
    print("\n[3] Live edit (no retrain), then re-ask a PARAPHRASE — edit transfers:")
    d.edit("France", "Lima")
    d.ask(clue)
    print("\n[4] Out-of-scope / unknown question — the model REFUSES instead of confabulating:")
    d.ask("What is the deal with quantum gravity and black holes")
    d.ask("Tell me about the best pizza topping")
    print("\n(Confidence = Engram top1-top2 centroid margin. The feature on show is ABSTENTION + edit;")
    print(" the answer is judged on the FIRST BYTE — this 807K model is a first-byte editor, so the")
    print(" multi-byte gen is garbled by design (full-word transfer ~0.05, measured). Evidence the")
    print(" confidence signal is real: s3 AURC 0.004 vs oracle 0.003; hard name-free 0.194 vs GRACE 0.436.")
    print(" A real, novel feature — no published editor refuses on low routing confidence — but a STEP,")
    print(" the pieces are individually published.)\n")
def main(argv=None):
    """CLI entry point.

    Args:
        argv: argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(description="Yaz+Engram answer-or-abstain fact-editor demo.")
    ap.add_argument("--prompt", type=str, default=None)
    ap.add_argument("--edit", type=str, default=None, help="Country=CapitalOrCountry, e.g. France=Lima")
    ap.add_argument("--delete", type=str, default=None, help="Country to delete, e.g. France")
    ap.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
    ap.add_argument("--demo", action="store_true")
    args = ap.parse_args(argv)

    # Validate --edit before the (slow) model load. A missing "=" used to
    # crash with an unpacking ValueError; report a usage error instead.
    edit_pair = None
    if args.edit:
        tgt, sep, rhs = args.edit.partition("=")
        tgt, rhs = tgt.strip(), rhs.strip()
        if not (sep and tgt and rhs):
            ap.error(f"--edit expects Country=CapitalOrCountry (e.g. France=Lima), got {args.edit!r}")
        edit_pair = (tgt, rhs)

    torch.set_num_threads(1)
    d = YazDemo(threshold=args.threshold)
    if args.demo:
        run_demo(d)
        return
    # Order matters when both target the same country: delete first, then edit.
    if args.delete:
        d.delete(args.delete)
    if edit_pair:
        d.edit(*edit_pair)
    if args.prompt:
        d.ask(args.prompt)
    elif not (args.edit or args.delete):
        ap.print_help()


if __name__ == "__main__":
    main()
|