# cento-engine/src/embed.py
# Cento v0.1 — bounded recombinant-memory engine (LJTSG, commit 8494d00)
# embed.py — the vector space. MiniLM-L6 (384d) via plain transformers,
# CPU only, mean-pooled + L2-normalized.
# python embed.py cache <texts.json> <out_prefix> -> out_prefix.f32 + .meta.json
# python embed.py query "some text"              -> JSON vector on stdout
# python embed.py serve                          -> "READY", then one JSON vector per stdin line
import sys, json, struct
import torch
from transformers import AutoTokenizer, AutoModel
MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def load():
    """Build the tokenizer/encoder pair for the ``MODEL`` checkpoint.

    The encoder is switched to eval mode before it is returned, so
    training-only behaviour such as dropout is inactive while embedding.

    Returns:
        ``(tok, model)`` — the tokenizer first, then the encoder.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    encoder = AutoModel.from_pretrained(MODEL)
    # nn.Module.eval() returns the module itself, so this hands back `encoder`
    return tokenizer, encoder.eval()
@torch.no_grad()
def embed(tok, model, texts, batch=64, max_length=128):
    """Encode ``texts`` into L2-normalized sentence vectors.

    Each batch is tokenized (padded, truncated to ``max_length`` tokens) and
    run through ``model``. The token states are mean-pooled over the real
    (non-padding) positions given by the attention mask, and each pooled
    vector is scaled to unit length.

    Args:
        tok: tokenizer callable, as returned by ``load()``.
        model: encoder whose output exposes ``last_hidden_state``.
        texts: list of strings to embed.
        batch: number of texts per forward pass; must be >= 1.
        max_length: per-text token cap. The default of 128 is the historical
            value, kept so new vectors stay comparable with existing caches.

    Returns:
        A float tensor with one unit-length row per input text. When
        ``texts`` is empty, an empty ``(0, model.config.hidden_size)`` tensor.

    Raises:
        ValueError: if ``batch`` is less than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    if not texts:
        # torch.cat([]) raises, so build an empty matrix of the right width
        return torch.empty(0, model.config.hidden_size)
    total = len(texts)
    out = []
    for i in range(0, total, batch):
        chunk = texts[i:i + batch]
        enc = tok(chunk, padding=True, truncation=True,
                  max_length=max_length, return_tensors="pt")
        h = model(**enc).last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1).float()
        # masked mean: padding positions contribute zero; clamp guards /0
        emb = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        out.append(torch.nn.functional.normalize(emb, dim=-1))
        if total > 200:
            print(f" embedded {min(i + batch, total)}/{total}", file=sys.stderr)
    return torch.cat(out, 0)
# Minimum len(sys.argv) each mode needs; extra trailing arguments are ignored.
_MIN_ARGV = {"cache": 4, "query": 3, "serve": 2}
_USAGE = (
    "usage: python embed.py cache <texts.json> <out_prefix>\n"
    "       python embed.py query <text>\n"
    "       python embed.py serve"
)


def _cache(tok, model, texts_path, out_prefix):
    """Embed a JSON array of strings and write the vectors to disk.

    Writes ``<out_prefix>.f32`` (C-order, native-endian float32 bytes, one row
    per text) and ``<out_prefix>.meta.json`` (``{"n": rows, "d": dims}``), then
    prints a one-line JSON status to stdout.
    """
    with open(texts_path, encoding="utf-8") as f:
        texts = json.load(f)
    arr = embed(tok, model, texts).numpy().astype("float32")
    n, d = int(arr.shape[0]), int(arr.shape[1])
    with open(out_prefix + ".f32", "wb") as f:
        f.write(arr.tobytes())
    with open(out_prefix + ".meta.json", "w", encoding="utf-8") as f:
        json.dump({"n": n, "d": d}, f)
    print(json.dumps({"ok": True, "n": n, "d": d}))


def _query(tok, model, text):
    """Print the embedding of a single text as a JSON list on stdout."""
    print(json.dumps(embed(tok, model, [text])[0].tolist()))


def _serve(tok, model):
    """Run a persistent line-oriented embedding server over stdin/stdout.

    The model stays loaded, which turns the ~7s cold start into ~30ms per
    query. Writes "READY" once. After that, each non-empty stdin line gets
    exactly one stdout line: a JSON vector, or "null" if embedding failed.
    """
    reconfigure = getattr(sys.stdin, "reconfigure", None)
    if reconfigure is not None:
        # Decode requests as UTF-8 (like the cache path) regardless of locale.
        # Undecodable bytes become U+FFFD instead of killing the loop.
        reconfigure(encoding="utf-8", errors="replace")
    sys.stdout.write("READY\n")
    sys.stdout.flush()
    for line in sys.stdin:
        line = line.rstrip("\n")
        if not line:
            continue
        try:
            reply = json.dumps(embed(tok, model, [line])[0].tolist())
        except Exception as e:  # keep serving; this request answers "null"
            print(f"embed failed: {e!r}", file=sys.stderr)
            reply = "null"
        sys.stdout.write(reply + "\n")
        sys.stdout.flush()


def main():
    """CLI entry point.

    Modes:
        cache <texts.json> <out_prefix>  embed a JSON list, write .f32 + .meta.json
        query <text>                     print one JSON vector
        serve                            persistent stdin/stdout embedding server

    Exits with status 2 and a usage message on stderr for a missing or
    unknown mode, or too few arguments.
    """
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    needed = _MIN_ARGV.get(mode)
    # validate before load(): loading the model is the slow part
    if needed is None or len(sys.argv) < needed:
        print(_USAGE, file=sys.stderr)
        sys.exit(2)
    tok, model = load()
    if mode == "cache":
        _cache(tok, model, sys.argv[2], sys.argv[3])
    elif mode == "query":
        _query(tok, model, sys.argv[2])
    else:
        _serve(tok, model)


if __name__ == "__main__":
    main()