"""KnowForge Encoder — standalone inference.

Predicts transform_type and answer_type from a KnowForge input prompt.

CLI:
    python inference.py "A cao hơn B, B cao hơn C. A có cao hơn C không?"

API:
    from inference import predict
    result = predict("A cao hơn B...")
"""

import json
import re
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Directory containing this module; vocab, config and weights are looked up here.
_HERE = Path(__file__).parent

# ── Label maps (must match training) ────────────────────────────────────────
TRANSFORM_LABELS = ["linear_to_cyclic", "relation_property_check", "relation_to_graph"]
ATYPE_LABELS = [
    "conditional_answer",
    "exact_answer",
    "need_more_rule",
    "unresolvable_without_observation",
]

# ── Tokenizer ────────────────────────────────────────────────────────────────
# A token is either a run of word characters or one non-space punctuation char.
_TOK_RE = re.compile(r"[\w]+|[^\w\s]", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lower-case *text* and split it into word and punctuation tokens."""
    return _TOK_RE.findall(text.lower())


def _encode(text: str, vocab: dict) -> list:
    """Map *text* to a list of vocabulary ids.

    Tokens missing from *vocab* fall back to the id stored under the ``""``
    key, or to 1 when that key is absent. An input that yields no tokens is
    encoded as ``[0]`` (index 0 is the embedding's padding index) so the model
    always receives a sequence of length >= 1.
    """
    # NOTE(review): the empty-string key may be a placeholder such as an
    # angle-bracketed unknown-token name that was stripped from this file;
    # confirm against the training vocabulary before changing it.
    unk_id = vocab.get("", 1)  # resolved once instead of once per token
    return [vocab.get(tok, unk_id) for tok in _tokenize(text)] or [0]


# ── Model architecture ───────────────────────────────────────────────────────
class _MultiTaskEncoder(nn.Module):
    """Embedding + stacked 1-D convolutions with one linear head per task.

    Only the ``transform`` and ``atype`` heads are used at inference time; the
    remaining heads exist so that checkpoint keys line up with this module.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 64,
        hidden_dim: int = 64,
        n_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        enc_dim = hidden_dim * 2  # 128 with the default hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        # n_layers x (Conv1d(kernel=3, padding=1) -> ReLU); sequence length is
        # preserved by the padding.
        conv_layers = []
        in_ch = embed_dim
        for _ in range(n_layers):
            conv_layers += [nn.Conv1d(in_ch, enc_dim, 3, padding=1), nn.ReLU()]
            in_ch = enc_dim
        self.encoder = nn.Sequential(*conv_layers)

        self.transform_head = nn.Linear(enc_dim, len(TRANSFORM_LABELS))
        self.atype_head = nn.Linear(enc_dim, len(ATYPE_LABELS))
        # Unused heads included so state_dict keys match exactly
        self.etype_head = nn.Linear(enc_dim, 24)
        self.uncertainty_head = nn.Linear(enc_dim, 5)
        self.bio_head = nn.Linear(enc_dim, 12)

    def forward(self, token_ids: torch.Tensor) -> dict:
        """Return ``{"transform": logits, "atype": logits}`` for a (B, L) id batch."""
        x = self.embedding(token_ids)  # (B, L, E)
        x = self.dropout(x)
        # Conv1d wants channels first, so swap to (B, E, L) and back again.
        out = self.encoder(x.transpose(1, 2)).transpose(1, 2)  # (B, L, enc_dim)
        # Global max pooling over sequence dim
        pooled = out.max(dim=1).values  # (B, enc_dim)
        return {
            "transform": self.transform_head(pooled),
            "atype": self.atype_head(pooled),
        }


# ── Lazy singleton loader ────────────────────────────────────────────────────
_encoder: Optional[_MultiTaskEncoder] = None
_vocab: Optional[dict] = None


def _read_json(path: Path):
    """Parse the JSON file at *path*, decoding it as UTF-8 and closing it."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _load():
    """Load the vocabulary and model once and return ``(model, vocab)``.

    Files are read from the directory containing this module:
    ``vocab.json`` (required), ``model_config.json`` (optional), and the
    weights from ``best_model.safetensors`` or, failing that, ``best_model.pt``.

    Raises:
        FileNotFoundError: if ``vocab.json`` or both weight files are missing.
    """
    global _encoder, _vocab
    if _encoder is not None:
        return _encoder, _vocab

    vocab_path = _HERE / "vocab.json"
    cfg_path = _HERE / "model_config.json"
    sf_path = _HERE / "best_model.safetensors"
    pt_path = _HERE / "best_model.pt"

    if not vocab_path.exists():
        raise FileNotFoundError(f"vocab.json not found at {vocab_path}")

    # Explicit UTF-8: the vocabulary holds non-ASCII tokens, and the platform
    # default encoding is not UTF-8 everywhere.
    vocab = _read_json(vocab_path)
    cfg = _read_json(cfg_path) if cfg_path.exists() else {}

    model = _MultiTaskEncoder(
        vocab_size=cfg.get("vocab_size", len(vocab)),
        embed_dim=cfg.get("embed_dim", 64),
        hidden_dim=cfg.get("hidden_dim", 64),
        n_layers=cfg.get("n_layers", 2),
        dropout=cfg.get("dropout", 0.3),
    )

    if sf_path.exists():
        # Imported lazily so safetensors is only needed when that file exists.
        from safetensors.torch import load_file

        state = load_file(str(sf_path))
    elif pt_path.exists():
        state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found at {sf_path} or {pt_path}")

    model.load_state_dict(state)
    model.eval()

    # Publish both globals together, only after everything loaded successfully.
    _encoder, _vocab = model, vocab
    return _encoder, _vocab


# ── Public API ───────────────────────────────────────────────────────────────
def predict(text: str) -> dict:
    """Predict transform_type and answer_type for a KnowForge input.

    Args:
        text: Natural-language input (rules + question or question alone).

    Returns:
        A dict with the keys:

        * ``"transform_type"`` (str) — one of linear_to_cyclic /
          relation_property_check / relation_to_graph.
        * ``"transform_confidence"`` (float) — softmax probability [0,1],
          rounded to 4 decimals.
        * ``"answer_type"`` (str) — one of conditional_answer / exact_answer /
          need_more_rule / unresolvable_without_observation.
        * ``"atype_confidence"`` (float) — softmax probability [0,1], rounded
          to 4 decimals.

    Raises:
        FileNotFoundError: if the vocabulary or model weights cannot be found.
    """
    model, vocab = _load()
    x = torch.tensor([_encode(text, vocab)], dtype=torch.long)  # (1, L)

    with torch.no_grad():
        logits = model(x)

    t_probs = F.softmax(logits["transform"][0], dim=-1)
    a_probs = F.softmax(logits["atype"][0], dim=-1)
    t_idx = int(t_probs.argmax())
    a_idx = int(a_probs.argmax())

    return {
        "transform_type": TRANSFORM_LABELS[t_idx],
        "transform_confidence": round(float(t_probs[t_idx]), 4),
        "answer_type": ATYPE_LABELS[a_idx],
        "atype_confidence": round(float(a_probs[a_idx]), 4),
    }


def _main():
    """CLI entry point: classify the text given as command-line arguments."""
    if len(sys.argv) < 2:
        # NOTE(review): the quoted placeholder in this message is empty; it may
        # originally have named the expected argument — confirm and restore.
        print("Usage: python inference.py \"\"")
        sys.exit(1)

    # Re-join the arguments so an unquoted multi-word prompt still works.
    text = " ".join(sys.argv[1:])
    result = predict(text)
    print(f"Transform: {result['transform_type']} ({result['transform_confidence']:.2%})")
    print(f"Answer type: {result['answer_type']} ({result['atype_confidence']:.2%})")


if __name__ == "__main__":
    _main()