| |
| |
| |
| |
| import sys, json, struct |
| import torch |
| from transformers import AutoTokenizer, AutoModel |
|
|
# Model identifier handed to AutoTokenizer/AutoModel.from_pretrained in load().
MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
|
def load():
    """Instantiate the tokenizer and encoder named by ``MODEL``.

    Returns:
        A ``(tokenizer, model)`` pair; the model has been switched to
        evaluation mode before it is returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    encoder = AutoModel.from_pretrained(MODEL).eval()
    return tokenizer, encoder
|
|
@torch.no_grad()
def embed(tok, model, texts, batch=64):
    """Embed ``texts`` as unit-length, mean-pooled vectors.

    Each batch is tokenised with padding and truncation (``max_length=128``),
    run through ``model``, averaged over the positions where the attention
    mask is non-zero, and L2-normalised along the last dimension.

    Args:
        tok: Tokenizer callable; must return a mapping that contains
            ``"attention_mask"`` and can be splatted into ``model``.
        model: Callable whose result exposes ``last_hidden_state``.
        texts: Sequence of strings (must support ``len`` and slicing).
        batch: Number of texts per forward pass; must be at least 1.

    Returns:
        A tensor with one row per input text. For empty ``texts`` an empty
        ``(0, d)`` tensor is returned, where ``d`` is
        ``model.config.hidden_size`` when the model exposes it and 0 otherwise.

    Raises:
        ValueError: If ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    total = len(texts)
    if total == 0:
        # torch.cat() rejects an empty list, so short-circuit with an empty
        # matrix instead of failing on "nothing to embed".
        config = getattr(model, "config", None)
        width = getattr(config, "hidden_size", 0) or 0
        return torch.empty((0, width), dtype=torch.float32)

    # Progress is only reported for larger jobs.
    show_progress = total > 200
    pieces = []
    for start in range(0, total, batch):
        stop = min(start + batch, total)
        enc = tok(
            texts[start:stop],
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        hidden = model(**enc).last_hidden_state
        # Broadcast the mask over the feature axis so padded positions
        # contribute nothing to the sum.
        mask = enc["attention_mask"].unsqueeze(-1).float()
        # clamp() guards against division by zero for an all-zero mask row.
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pieces.append(torch.nn.functional.normalize(pooled, dim=-1))
        if show_progress:
            print(f" embedded {stop}/{total}", file=sys.stderr)
    return torch.cat(pieces, 0)
|
|
def main():
    """Command-line entry point; the mode is taken from ``sys.argv[1]``.

    Modes:
        cache <texts.json> <out_prefix>
            Read a JSON list of texts, embed it, write the raw float32 matrix
            to ``<out_prefix>.f32`` and its shape to ``<out_prefix>.meta.json``,
            then print a JSON status object to stdout.
        query <text>
            Print the embedding of ``<text>`` as a JSON array.
        serve
            Print ``READY``, then answer each non-empty stdin line with one
            stdout line: a JSON array, or ``null`` if embedding failed.

    Raises:
        SystemExit: With status 2 (after printing usage to stderr) when the
            mode is missing or unknown, or the mode's arguments are missing.
    """
    argv = sys.argv
    # Minimum len(argv) each mode needs, counting the script name.
    required = {"cache": 4, "query": 3, "serve": 2}
    mode = argv[1] if len(argv) > 1 else None
    # Validate before load() so a bad invocation fails fast instead of paying
    # for a model load first.
    if mode not in required or len(argv) < required[mode]:
        print(
            f"usage: {argv[0]} cache <texts.json> <out_prefix>"
            " | query <text> | serve",
            file=sys.stderr,
        )
        sys.exit(2)

    tok, model = load()

    if mode == "cache":
        with open(argv[2], encoding="utf-8") as src:
            texts = json.load(src)
        arr = embed(tok, model, texts).numpy().astype("float32")
        n, d = int(arr.shape[0]), int(arr.shape[1])
        with open(argv[3] + ".f32", "wb") as f:
            f.write(arr.tobytes())
        with open(argv[3] + ".meta.json", "w", encoding="utf-8") as f:
            json.dump({"n": n, "d": d}, f)
        print(json.dumps({"ok": True, "n": n, "d": d}))
    elif mode == "query":
        vec = embed(tok, model, [argv[2]])
        print(json.dumps(vec[0].tolist()))
    else:  # mode == "serve"
        out = sys.stdout
        # Tell the parent process the model is loaded and requests may be sent.
        out.write("READY\n")
        out.flush()
        for line in sys.stdin:
            line = line.rstrip("\n")
            if not line:
                continue
            try:
                reply = json.dumps(embed(tok, model, [line])[0].tolist())
            except Exception as exc:
                # Keep the one-reply-per-request protocol intact, but leave a
                # trace of the failure on stderr instead of swallowing it.
                print(f"embed failed: {exc!r}", file=sys.stderr)
                reply = "null"
            out.write(reply + "\n")
            # Flush per reply so a reader blocked on this line is released.
            out.flush()
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|