File size: 8,488 Bytes
b14638e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7289a8
 
 
b14638e
 
 
 
 
e7289a8
 
 
b14638e
 
 
 
 
 
 
 
 
e7289a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14638e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Yaz+Engram demo β€” an editable knowledge model that KNOWS WHEN IT DOESN'T KNOW.

Free-text prompt -> route it to a fact-atom via a FROZEN Engram (MiniLM) embedding ->
if the routing CONFIDENCE MARGIN (top1-top2 centroid cosine) is high enough, ANSWER
(emit the routed fact); else ABSTAIN ("I'm not confident which fact you mean") and show
the top-2 candidates. Facts are live-editable (--edit) and deletable (--delete) with no
retraining, and the confidence signal is unchanged by edits.

This packages the routing-confidence abstention result (scripts/scaling/s3_route_abstain.py:
AURC 0.004 vs oracle 0.003; on hard name-free clues 0.194 vs GRACE-distance 0.436). Honest
scope: 807K byte-LM, country->capital, first-byte routing, CPU. A real, novel feature
(no published editor refuses on low routing confidence) — but not a unique advantage (the
pieces are copyable). Self-contained: imports the router/model but edits no shared file.

Usage:
  python demo.py --demo                         # scripted transcript
  python demo.py --prompt "the country of the Eiffel Tower"
  python demo.py --prompt "The capital of France is " --edit France=Lima
  python demo.py --prompt "What is the deal with quantum gravity"   # -> ABSTAIN
"""
from __future__ import annotations
import os
import argparse, sys
from pathlib import Path
import numpy as np
import torch

# Directory containing this script; the checkpoint, sidecar and data paths below are built from it.
ROOT = Path(__file__).resolve().parent
# Optional extra import location, read from the YAZ_EMBEDDER_PATH environment variable.
_emb_path = os.environ.get("YAZ_EMBEDDER_PATH", "")
if _emb_path:                       # only add a real path (avoid inserting "" == cwd)
    sys.path.insert(0, _emb_path)
# ROOT is inserted at index 0 after _emb_path, so ROOT is searched first.
sys.path.insert(0, str(ROOT))
# Project-local imports; they must come after the sys.path edits above.
from yaz import YazConfig, YazLM
from yaz.semantic_router import SemanticRouter
from scripts.gen_paraphrase_data import TRAIN_TEMPLATES

# Model artifacts. YazDemo uses SAFETENSORS + META when both exist, otherwise CKPT.
CKPT = ROOT / "checkpoints" / "yaz_gen_semantic_v2.pt"   # PyTorch pickle (fallback)
SAFETENSORS = ROOT / "model.safetensors"                 # recommended, pickle-free
META = ROOT / "yaz_meta.json"
DEFAULT_THRESHOLD = 0.08          # margin below this -> abstain (tunable; see s3 risk-coverage)


def ids(s):
    """Encode ``s`` as UTF-8 and return its byte values as a ``(1, n_bytes)`` int64 tensor."""
    byte_values = list(s.encode("utf-8"))
    # Nesting the list inside an outer list supplies the leading dimension of size 1.
    return torch.tensor([byte_values], dtype=torch.long)


class YazDemo:
    """Answer-or-abstain front end over a routed, live-editable fact model.

    A prompt is embedded by the semantic router and scored against one centroid
    per known country. The gap between the best and second-best score is the
    confidence margin: when it is below ``threshold`` the demo abstains and
    names the top two candidates, otherwise it generates bytes with the routed
    atom and prints them.

    Attributes:
        cfg: ``YazConfig`` built from the stored ``cfg`` mapping.
        model: ``YazLM`` with the loaded weights, switched to eval mode.
        c2i: Mapping country name -> target atom index (stored with the weights).
        order: Country names in ``c2i`` key order; routing positions index this list.
        threshold: Margin below which ``ask`` abstains.
        capof: Mapping country -> capital, first occurrence in the facts file wins.
        router: ``SemanticRouter`` built over ``order`` and ``TRAIN_TEMPLATES``.
        C: The router's centroid matrix.
    """

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        """Load weights, the country/atom map, the capital table and the router.

        Args:
            threshold: Minimum top1-top2 similarity margin required to answer.
        """
        import json

        # Prefer the pickle-free safetensors + JSON sidecar (the recommended artifact);
        # fall back to the PyTorch pickle checkpoint if they're not present.
        if SAFETENSORS.exists() and META.exists():
            from safetensors.torch import load_file

            # JSON is UTF-8 by specification; do not depend on the locale default.
            meta = json.loads(META.read_text(encoding="utf-8"))
            cfg_kwargs = meta["cfg"]
            state = load_file(str(SAFETENSORS))
            self.c2i = meta["country_to_target_atom"]
        else:
            # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
            # objects, which can execute code. Only use this path with a trusted
            # checkpoint file.
            ck = torch.load(CKPT, map_location="cpu", weights_only=False)
            cfg_kwargs = ck["cfg"]
            state = ck["model"]
            self.c2i = ck["country_to_target_atom"]
        self.cfg = YazConfig(**cfg_kwargs)
        self.model = YazLM(self.cfg)
        self.model.load_state_dict(state)
        self.model.eval()

        self.order = list(self.c2i.keys())
        self.threshold = threshold

        # capitals (full word) for resolving --edit RHS and for nicer output
        self.capof = {}
        facts_path = ROOT / "data" / "facts_para_train.jsonl"
        for line in facts_path.read_text(encoding="utf-8").splitlines():
            if not line.strip():        # tolerate blank / whitespace-only lines
                continue
            rec = json.loads(line)
            # setdefault keeps the first capital seen for each country.
            self.capof.setdefault(rec["country"], rec["capital"])

        self.router = SemanticRouter(self.order, TRAIN_TEMPLATES)
        self.router.build_centroids()
        self.C = self.router.centroids                       # (n, dim) unit-norm

    # ---- routing with confidence ----
    def route(self, prompt):
        """Score ``prompt`` against every centroid.

        Returns:
            A tuple ``(position, margin, top3)`` where ``position`` is the index
            into ``self.order`` of the best match, ``margin`` is the best score
            minus the runner-up score, and ``top3`` is a list of up to three
            ``(country, score)`` pairs in descending score order.
        """
        v = self.router.embed(prompt)
        sims = self.C @ v
        ranked = np.argsort(-sims)      # indices sorted by descending similarity
        best, runner_up = ranked[0], ranked[1]
        margin = float(sims[best] - sims[runner_up])
        top3 = [(self.order[int(i)], float(sims[i])) for i in ranked[:3]]
        return int(best), margin, top3

    @torch.no_grad()
    def _gen(self, prompt, atom, n=10):
        """Greedily generate ``n`` bytes after ``prompt`` with routing fixed to ``atom``.

        Returns the generated bytes decoded as latin-1 and stripped of
        surrounding whitespace.
        """
        route_atom = torch.tensor([int(atom)])
        out = ids(prompt)
        prompt_len = out.shape[1]
        max_ctx = self.cfg.max_seq_len
        for _ in range(n):
            # Feed at most the last max_ctx positions to the model.
            ctx = out if out.shape[1] <= max_ctx else out[:, -max_ctx:]
            logits = self.model(ctx, route_atom=route_atom)
            next_byte = logits[:, -1].argmax(-1, keepdim=True)
            out = torch.cat([out, next_byte], dim=1)
        generated = bytes(out[0, prompt_len:].tolist())
        # latin-1 maps every byte value to a character, so decoding cannot fail.
        return generated.decode("latin-1", "ignore").strip()

    # ---- CRUD (live, no retrain) ----
    def _resolve_source(self, rhs):
        """RHS of --edit may be a country name or a capital; return source country.

        Country names are matched exactly; capitals are matched case-insensitively.
        Returns ``None`` when nothing matches.
        """
        if rhs in self.c2i:
            return rhs
        wanted = rhs.lower()
        for country, capital in self.capof.items():
            if capital.lower() == wanted:
                return country
        return None

    def edit(self, tgt, rhs):
        """Overwrite ``tgt``'s decoder column with a copy of the source country's column.

        Prints a skip notice and does nothing if ``tgt`` is unknown or ``rhs``
        resolves to no country.
        """
        src = self._resolve_source(rhs)
        if tgt not in self.c2i or src is None:
            print(f"  [edit skipped: unknown '{tgt}' or '{rhs}']")
            return
        w_dec = self.model.fact_layer.W_dec.weight
        with torch.no_grad():
            # clone() so the assignment does not alias the source column.
            w_dec[:, self.c2i[tgt]] = w_dec[:, self.c2i[src]].clone()
        print(f"  ✎ EDIT applied: {tgt}'s capital -> {self.capof.get(src, src)} (copied {src}'s atom)")

    def delete(self, tgt):
        """Zero ``tgt``'s atom: its decoder column, encoder row and encoder bias entry.

        Prints a skip notice and does nothing if ``tgt`` is unknown.
        """
        if tgt not in self.c2i:
            print(f"  [delete skipped: unknown '{tgt}']")
            return
        a = self.c2i[tgt]
        layer = self.model.fact_layer
        with torch.no_grad():
            layer.W_dec.weight[:, a] = 0.0
            layer.W_enc.weight[a, :] = 0.0
            layer.W_enc.bias[a] = 0.0
        print(f"  🗑 DELETE applied: {tgt}'s atom zeroed")

    # ---- the answer-or-abstain decision ----
    def ask(self, prompt):
        """Route ``prompt`` and print either an ANSWER line or an ABSTAIN line."""
        position, margin, top3 = self.route(prompt)
        country = self.order[position]
        print(f'  Q: "{prompt}"')
        if margin < self.threshold:
            cands = ", ".join(str(name) for name, _ in top3[:2])
            print(f"  → ABSTAIN (margin {margin:.3f} < {self.threshold}). "
                  f"I'm not confident which fact you mean — {cands}?")
            return
        # Look the atom id up through c2i (the same mapping edit/delete use)
        # instead of assuming it equals the routing position.
        gen = self._gen(prompt, self.c2i[country])
        fb = gen[:1]                    # _gen already strips surrounding whitespace
        print(f"  → ANSWER (routed: {country}, margin {margin:.3f}): first byte '{fb}'  "
              f"[raw gen: {gen!r}]")


def run_demo(d):
    """Print the scripted transcript by driving ``d.ask`` and ``d.edit``.

    Args:
        d: Object exposing ``ask(prompt)`` and ``edit(tgt, rhs)``. The edit in
            step [3] is applied to ``d`` and is not undone afterwards.
    """
    # The same name-free clue is asked before and after the edit.
    clue = "the country of the Eiffel Tower and the Louvre, its capital is "

    print("\n=== Yaz+Engram: an editor that knows when it doesn't know ===\n")
    print("[1] Confident, in-scope question:")
    d.ask("The capital of France is ")
    print("\n[2] Name-free clue (semantic routing does the entity-ID):")
    d.ask(clue)
    print("\n[3] Live edit (no retrain), then re-ask a PARAPHRASE — edit transfers:")
    d.edit("France", "Lima")
    d.ask(clue)
    print("\n[4] Out-of-scope / unknown question — the model REFUSES instead of confabulating:")
    d.ask("What is the deal with quantum gravity and black holes")
    d.ask("Tell me about the best pizza topping")
    print("\n(Confidence = Engram top1-top2 centroid margin. The feature on show is ABSTENTION + edit;")
    print(" the answer is judged on the FIRST BYTE — this 807K model is a first-byte editor, so the")
    print(" multi-byte gen is garbled by design (full-word transfer ~0.05, measured). Evidence the")
    print(" confidence signal is real: s3 AURC 0.004 vs oracle 0.003; hard name-free 0.194 vs GRACE 0.436.")
    print(" A real, novel feature — no published editor refuses on low routing confidence — but a STEP,")
    print(" the pieces are individually published.)\n")


def main():
    """Command-line entry point.

    With ``--demo`` the scripted transcript runs and other actions are ignored.
    Otherwise the requested actions run in a fixed order: ``--delete``, then
    ``--edit``, then ``--prompt``. With no action at all the help text is
    printed. A ``--edit`` value without ``=`` exits with a usage error.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--prompt", type=str, default=None)
    ap.add_argument("--edit", type=str, default=None, help="Country=CapitalOrCountry, e.g. France=Lima")
    ap.add_argument("--delete", type=str, default=None, help="Country to delete, e.g. France")
    ap.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD)
    ap.add_argument("--demo", action="store_true")
    args = ap.parse_args()

    # Validate --edit before the (slow) model load so a typo fails fast with a
    # usage message instead of an unpacking ValueError. --demo ignores --edit,
    # so its value is not checked in that mode.
    edit_pair = None
    if args.edit and not args.demo:
        tgt, sep, rhs = args.edit.partition("=")
        if not sep:
            ap.error("--edit expects Country=CapitalOrCountry, e.g. France=Lima")
        edit_pair = (tgt.strip(), rhs.strip())

    # Nothing to do: show help without loading the model.
    if not (args.demo or args.prompt or args.edit or args.delete):
        ap.print_help()
        return

    torch.set_num_threads(1)
    d = YazDemo(threshold=args.threshold)
    if args.demo:
        run_demo(d)
        return
    if args.delete:
        d.delete(args.delete)
    if edit_pair is not None:
        d.edit(*edit_pair)
    if args.prompt:
        d.ask(args.prompt)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()