# knowforge-encoder / inference.py
# qox: Initial upload: KnowForge Encoder (131K params)
# commit 578c1ba (verified)
"""
KnowForge Encoder — standalone inference.
Predicts transform_type and answer_type from a KnowForge input prompt.
CLI: python inference.py "A cao hơn B, B cao hơn C. A có cao hơn C không?"
API: from inference import predict; result = predict("A cao hơn B...")
"""
import json
import re
import sys
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory holding this script; vocab, config and weights are resolved
# relative to it.
_HERE = Path(__file__).parent

# ── Label maps (must match training) ──────────────────────────────────────────
# Position i in each list is the name of class i of the matching head, so the
# order must not change.
TRANSFORM_LABELS = [
    "linear_to_cyclic",
    "relation_property_check",
    "relation_to_graph",
]
ATYPE_LABELS = [
    "conditional_answer",
    "exact_answer",
    "need_more_rule",
    "unresolvable_without_observation",
]
# ── Tokenizer ────────────────────────────────────────────────────────────────
# A token is either a run of word characters or one non-space symbol.
_TOK_RE = re.compile(r"[\w]+|[^\w\s]", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lower-case *text* and split it into word and punctuation tokens.

    Whitespace only separates tokens and is never emitted; every other
    non-word character becomes a token of its own.
    """
    lowered = text.lower()
    return [match.group() for match in _TOK_RE.finditer(lowered)]
# ── Model architecture ───────────────────────────────────────────────────────
class _MultiTaskEncoder(nn.Module):
    """Convolutional text encoder feeding several linear classification heads.

    Token ids are embedded, passed through ``n_layers`` blocks of
    ``Conv1d(kernel_size=3, padding=1)`` + ``ReLU``, max-pooled over the
    sequence axis and projected by the heads. ``forward`` only returns the
    ``transform`` and ``atype`` logits.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 64,
                 hidden_dim: int = 64, n_layers: int = 2, dropout: float = 0.3):
        """Create the embedding table, the convolution stack and all heads.

        Args:
            vocab_size: Number of rows in the embedding table (id 0 is the
                padding index).
            embed_dim: Width of each token embedding.
            hidden_dim: Half of the encoder channel count.
            n_layers: Number of Conv1d + ReLU blocks.
            dropout: Dropout probability applied to the embeddings.
        """
        super().__init__()
        width = 2 * hidden_dim  # 128 with the default hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(p=dropout)

        # The first block maps embed_dim -> width; later blocks keep the width.
        channel_sizes = [embed_dim] + [width] * n_layers
        stack = []
        for src, dst in zip(channel_sizes[:-1], channel_sizes[1:]):
            stack.append(nn.Conv1d(src, dst, kernel_size=3, padding=1))
            stack.append(nn.ReLU())
        self.encoder = nn.Sequential(*stack)

        self.transform_head = nn.Linear(width, len(TRANSFORM_LABELS))
        self.atype_head = nn.Linear(width, len(ATYPE_LABELS))

        # Never used by forward(); they exist only so that the state_dict
        # keys match the saved checkpoint exactly.
        self.etype_head = nn.Linear(width, 24)
        self.uncertainty_head = nn.Linear(width, 5)
        self.bio_head = nn.Linear(width, 12)

    def forward(self, token_ids: torch.Tensor) -> dict:
        """Return the ``transform`` and ``atype`` logits for a batch of ids.

        Args:
            token_ids: Integer tensor of token ids, shape (B, L).

        Returns:
            Dict with keys ``"transform"`` and ``"atype"``; each value holds
            one row of logits per batch element.
        """
        embedded = self.dropout(self.embedding(token_ids))  # (B, L, E)
        # Conv1d wants channels first, hence the transpose to (B, E, L).
        features = self.encoder(embedded.transpose(1, 2))   # (B, C, L)
        # Global max pooling over the sequence axis.
        pooled, _ = features.max(dim=2)                     # (B, C)
        return {
            "transform": self.transform_head(pooled),
            "atype": self.atype_head(pooled),
        }
# ── Lazy singleton loader ────────────────────────────────────────────────────
# Process-wide cache, filled by the first successful _load() call.
_encoder: Optional["_MultiTaskEncoder"] = None
_vocab: Optional[dict] = None


def _load():
    """Load the vocabulary and model once and return the cached pair.

    All files are looked up in ``_HERE``:

    * ``vocab.json`` - required.
    * ``model_config.json`` - optional; keys it does not provide fall back
      to the model defaults (``vocab_size`` falls back to ``len(vocab)``).
    * ``best_model.safetensors`` if present, otherwise ``best_model.pt``.

    Returns:
        ``(model, vocab)``: the encoder switched to eval mode and the
        vocabulary dict loaded from ``vocab.json``.

    Raises:
        FileNotFoundError: if ``vocab.json`` is missing, or if neither
            weight file exists.
    """
    global _encoder, _vocab
    if _encoder is not None:
        return _encoder, _vocab

    vocab_path = _HERE / "vocab.json"
    cfg_path = _HERE / "model_config.json"
    sf_path = _HERE / "best_model.safetensors"
    pt_path = _HERE / "best_model.pt"

    if not vocab_path.exists():
        raise FileNotFoundError(f"vocab.json not found at {vocab_path}")

    # Decode as UTF-8 explicitly: the platform default encoding (e.g. cp1252
    # on Windows) would corrupt or reject non-ASCII tokens. read_text() also
    # closes the file, unlike a bare open() handed to json.load().
    vocab = json.loads(vocab_path.read_text(encoding="utf-8"))
    if cfg_path.exists():
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    else:
        cfg = {}

    model = _MultiTaskEncoder(
        vocab_size=cfg.get("vocab_size", len(vocab)),
        embed_dim=cfg.get("embed_dim", 64),
        hidden_dim=cfg.get("hidden_dim", 64),
        n_layers=cfg.get("n_layers", 2),
        dropout=cfg.get("dropout", 0.3),
    )

    if sf_path.exists():
        # Imported lazily so safetensors is only needed when that file is used.
        from safetensors.torch import load_file
        state = load_file(str(sf_path))
    elif pt_path.exists():
        state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found at {sf_path} or {pt_path}")

    model.load_state_dict(state)
    model.eval()

    # Publish both globals only after everything succeeded, so a failed load
    # never leaves a half-initialised cache behind.
    _vocab = vocab
    _encoder = model
    return _encoder, _vocab
# ── Public API ───────────────────────────────────────────────────────────────
def predict(text: str) -> dict:
    """Predict transform_type and answer_type for a KnowForge input.

    Args:
        text: Natural-language input (rules + question, or question alone).

    Returns:
        A dict with four keys:

        * ``"transform_type"`` (str) - one of ``TRANSFORM_LABELS``:
          linear_to_cyclic / relation_property_check / relation_to_graph.
        * ``"transform_confidence"`` (float) - softmax probability of that
          label, rounded to 4 decimals.
        * ``"answer_type"`` (str) - one of ``ATYPE_LABELS``:
          conditional_answer / exact_answer / need_more_rule /
          unresolvable_without_observation.
        * ``"atype_confidence"`` (float) - softmax probability of that
          label, rounded to 4 decimals.
    """
    model, vocab = _load()

    # Tokens missing from the vocabulary map to the <UNK> id (1 if absent).
    unk_id = vocab.get("<UNK>", 1)
    token_ids = [vocab.get(tok, unk_id) for tok in _tokenize(text)]
    if not token_ids:
        # Input with no tokens: feed a single padding id instead.
        token_ids = [0]

    batch = torch.tensor([token_ids], dtype=torch.long)  # (1, L)
    with torch.no_grad():
        outputs = model(batch)
        transform_probs = F.softmax(outputs["transform"][0], dim=-1)
        atype_probs = F.softmax(outputs["atype"][0], dim=-1)

    transform_idx = transform_probs.argmax().item()
    atype_idx = atype_probs.argmax().item()

    return {
        "transform_type": TRANSFORM_LABELS[transform_idx],
        "transform_confidence": round(transform_probs[transform_idx].item(), 4),
        "answer_type": ATYPE_LABELS[atype_idx],
        "atype_confidence": round(atype_probs[atype_idx].item(), 4),
    }
def _main():
    """Command-line entry point: classify the text given as arguments.

    All command-line arguments are joined with single spaces into one input
    string. With no arguments, a usage line is printed and the process exits
    with status 1.
    """
    args = sys.argv[1:]
    if not args:
        print('Usage: python inference.py "<input text>"')
        sys.exit(1)

    result = predict(" ".join(args))
    transform, transform_conf = result["transform_type"], result["transform_confidence"]
    answer, answer_conf = result["answer_type"], result["atype_confidence"]
    print(f"Transform: {transform} ({transform_conf:.2%})")
    print(f"Answer type: {answer} ({answer_conf:.2%})")


if __name__ == "__main__":
    _main()