# Yaz / load_yaz.py
# Uploaded by TilelliLab.
# Commit b14638e (verified): Yaz v0.0.1 — safetensors + loader + model card + paper
# (editable/auditable tiny KB model). File size: 1.37 kB.
"""Load Yaz from the safetensors weights + JSON sidecar in this Hugging Face repo.

Usage::

    from load_yaz import load_yaz
    model, cfg, meta = load_yaz()  # reads model.safetensors + yaz_meta.json next to this file
    # `model` is a YazLM in eval mode; `cfg` is its YazConfig; `meta` is the parsed
    # yaz_meta.json: the "cfg" dict plus the country -> fact-atom map
    # ("country_to_target_atom") and the country -> capital-first-byte map
    # ("country_to_capital_first").

No pickle is loaded: weights come from `model.safetensors`. The `yaz/` package (model code)
is vendored alongside this file. See `demo.py` for routing, abstention, and live-edit usage.
"""
from __future__ import annotations
import json
import os
import torch
from safetensors.torch import load_file
from yaz import YazConfig, YazLM
# Absolute directory containing this file; the default weight/sidecar
# filenames used by the loader are resolved relative to it.
HERE = os.path.split(os.path.abspath(__file__))[0]
def load_yaz(
    weights: str = "model.safetensors",
    meta: str = "yaz_meta.json",
    *,
    root: "str | os.PathLike | None" = None,
) -> "tuple[YazLM, YazConfig, dict]":
    """Build a YazLM and load its weights from safetensors plus a JSON sidecar.

    Args:
        weights: Filename of the safetensors checkpoint. A relative name is
            resolved against ``root``; an absolute path is used as-is
            (``os.path.join`` discards ``root`` for absolute components).
        meta: Filename of the JSON sidecar. Its ``"cfg"`` entry is unpacked as
            keyword arguments into ``YazConfig``.
        root: Directory used to resolve relative ``weights``/``meta`` names.
            Defaults to the directory containing this file (``HERE``), which
            keeps the original call ``load_yaz()`` working unchanged.

    Returns:
        ``(model, cfg, meta_dict)``: the ``YazLM`` switched to eval mode, the
        ``YazConfig`` built from ``meta_dict["cfg"]``, and the whole parsed
        JSON sidecar.

    Raises:
        KeyError: if the sidecar has no ``"cfg"`` entry.
        Errors from ``open``/``json.load`` (missing or malformed sidecar),
        from ``load_file`` (missing or invalid checkpoint), and from
        ``load_state_dict`` (key/shape mismatch) propagate unchanged.
    """
    # Resolve the default lazily so the module-level HERE is read at call time.
    base = HERE if root is None else root

    with open(os.path.join(base, meta), "r", encoding="utf-8") as f:
        m = json.load(f)

    cfg = YazConfig(**m["cfg"])
    model = YazLM(cfg)

    # Weights come from safetensors, not pickle.
    state = load_file(os.path.join(base, weights))
    # Default strict loading: the checkpoint keys must match the model exactly.
    model.load_state_dict(state)
    model.eval()
    return model, cfg, m
if __name__ == "__main__":
    # Smoke test: load the bundled checkpoint and print a short summary.
    yaz, _cfg, info = load_yaz()

    # Total number of scalar parameters across all tensors in the model.
    total = 0
    for param in yaz.parameters():
        total += param.numel()

    atoms = info["country_to_target_atom"]
    print(f"Loaded Yaz: {total:,} parameters | {len(atoms)} fact-atoms")
    print("France -> atom", atoms["France"],
          "| capital first byte", info["country_to_capital_first"]["France"])