| from __future__ import annotations |
| import json |
| from pathlib import Path |
| import torch |
|
|
# Default Hugging Face Hub repository that load_from_hub() pulls checkpoints from.
DEFAULT_REPO_ID = "jackyoung27/matrix-sae"
|
|
|
|
| def _root(root=None): |
| return Path(root).resolve() if root else Path(__file__).resolve().parent |
|
|
|
|
def load_manifest(root=None):
    """Parse and return ``manifest.json`` from the repo root.

    *root* defaults to the directory containing this file.
    """
    base = Path(root).resolve() if root else Path(__file__).resolve().parent
    with (base / "manifest.json").open() as fh:
        return json.load(fh)
|
|
|
|
| def _find(spec, m): |
| for e in m["checkpoints"]: |
| if e["relative_path"] == spec or e["tag"] == spec: |
| return e |
| raise KeyError(spec) |
|
|
|
|
def list_checkpoints(root=None, group=None, sae_type=None):
    """Return manifest checkpoint entries, optionally filtered.

    *group* filters on the entry's ``group`` field; *sae_type* filters on
    ``parsed["sae_type"]``. Falsy filter values are ignored.
    """
    entries = load_manifest(root)["checkpoints"]
    if group:
        entries = [entry for entry in entries if entry["group"] == group]
    if sae_type:
        entries = [
            entry for entry in entries
            if entry["parsed"].get("sae_type") == sae_type
        ]
    return entries
|
|
|
|
def load_checkpoint(spec, device="cpu", root=None):
    """Load a checkpoint from the local repo by tag or relative path.

    Returns a 4-tuple ``(model, cfg, entry, ckpt)``: the SAE model moved to
    *device* and set to eval mode, the parsed config dict, the matching
    manifest entry, and the raw checkpoint dict.
    """
    import sys

    base = _root(root)
    entry = _find(spec, load_manifest(base))
    # Make the repo root importable so the bundled sae.py module resolves.
    if str(base) not in sys.path:
        sys.path.insert(0, str(base))
    from sae import build_sae_from_config

    ckpt_dir = base / entry["relative_path"]
    cfg = json.loads((ckpt_dir / "config.json").read_text())
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_dir / "best.pt", map_location="cpu", weights_only=False)
    model = build_sae_from_config(cfg, state_dict=ckpt["model_state_dict"])
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(device).eval(), cfg, entry, ckpt
|
|
|
|
def load_from_hub(spec, repo_id=DEFAULT_REPO_ID, device="cpu", revision=None, cache_dir=None):
    """Download a checkpoint from the Hugging Face Hub and load it.

    *spec* is a tag or relative path resolved against the repo's
    manifest.json. Returns ``(model, cfg, entry, ckpt)`` exactly like
    ``load_checkpoint``.
    """
    from huggingface_hub import hf_hub_download
    import importlib.util
    import sys

    cache = str(cache_dir) if cache_dir else None

    def fetch(name):
        # Download one file from the repo, honoring revision/cache settings.
        return hf_hub_download(
            repo_id=repo_id, filename=name, revision=revision, cache_dir=cache
        )

    manifest = json.loads(Path(fetch("manifest.json")).read_text())
    entry = _find(spec, manifest)
    rel = entry["relative_path"]
    cfg_path = fetch(f"{rel}/config.json")
    ckpt_path = fetch(f"{rel}/best.pt")
    sae_path = fetch("sae.py")

    # Put the download dir on sys.path in case sae.py has sibling imports.
    sae_dir = str(Path(sae_path).parent)
    if sae_dir not in sys.path:
        sys.path.insert(0, sae_dir)

    # Import the downloaded sae.py directly from its file location.
    module_spec = importlib.util.spec_from_file_location("sae", sae_path)
    assert module_spec and module_spec.loader
    sae_mod = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(sae_mod)

    cfg = json.loads(Path(cfg_path).read_text())
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load from trusted repos.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model = sae_mod.build_sae_from_config(cfg, state_dict=ckpt["model_state_dict"])
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(device).eval(), cfg, entry, ckpt
|
|