File size: 6,457 Bytes
578c1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
KnowForge Encoder — standalone inference.

Predicts transform_type and answer_type from a KnowForge input prompt.

CLI:   python inference.py "A cao hơn B, B cao hơn C. A có cao hơn C không?"
API:   from inference import predict; result = predict("A cao hơn B...")
"""
import json
import re
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

_HERE = Path(__file__).parent

# ── Label maps (must match training) ────────────────────────────────────────

# Class-name tables. Order matters: predict() indexes these lists with the
# argmax position of the corresponding head's output.
TRANSFORM_LABELS = [
    "linear_to_cyclic",
    "relation_property_check",
    "relation_to_graph",
]
ATYPE_LABELS = [
    "conditional_answer",
    "exact_answer",
    "need_more_rule",
    "unresolvable_without_observation",
]

# ── Tokenizer ────────────────────────────────────────────────────────────────

# A token is either a run of word characters or one single character that is
# neither a word character nor whitespace (i.e. a punctuation mark/symbol).
_TOK_RE = re.compile(r"[\w]+|[^\w\s]", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lower-case *text* and split it into word and punctuation tokens.

    Whitespace is dropped; every other character ends up in exactly one
    token. Returns a list of strings (empty when *text* has no tokens).
    """
    lowered = text.lower()
    tokens = _TOK_RE.findall(lowered)
    return tokens


# ── Model architecture ───────────────────────────────────────────────────────

class _MultiTaskEncoder(nn.Module):
    """Embedding + 1-D convolution encoder with one linear head per task.

    Token ids are embedded, passed through ``n_layers`` Conv1d/ReLU pairs,
    max-pooled over the sequence axis and fed to the classification heads.

    Args:
        vocab_size: Number of rows in the embedding table (id 0 is padding).
        embed_dim: Embedding width.
        hidden_dim: Half of the encoder channel width (channels = 2 * hidden_dim).
        n_layers: Number of Conv1d + ReLU pairs in the encoder.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 64,
                 hidden_dim: int = 64, n_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        enc_dim = 2 * hidden_dim  # 128 with the default hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        # Only the first convolution reads embed_dim channels; every later
        # one reads the enc_dim channels produced by its predecessor.
        stack = []
        for layer_idx in range(n_layers):
            channels_in = embed_dim if layer_idx == 0 else enc_dim
            stack.append(nn.Conv1d(channels_in, enc_dim, kernel_size=3, padding=1))
            stack.append(nn.ReLU())
        self.encoder = nn.Sequential(*stack)

        self.transform_head = nn.Linear(enc_dim, len(TRANSFORM_LABELS))
        self.atype_head = nn.Linear(enc_dim, len(ATYPE_LABELS))
        # These three heads are never evaluated in forward(); they are
        # declared only so the module's state_dict keys match the checkpoint.
        self.etype_head = nn.Linear(enc_dim, 24)
        self.uncertainty_head = nn.Linear(enc_dim, 5)
        self.bio_head = nn.Linear(enc_dim, 12)

    def forward(self, token_ids: torch.Tensor) -> dict:
        """Return raw logits under the keys ``"transform"`` and ``"atype"``."""
        embedded = self.dropout(self.embedding(token_ids))   # (B, L, E)
        # Conv1d wants channels before length, hence the transpose.
        features = self.encoder(embedded.transpose(1, 2))    # (B, C, L)
        # Global max pooling over the sequence (last) axis.
        pooled = features.max(dim=2).values                  # (B, C)
        logits = {}
        logits["transform"] = self.transform_head(pooled)
        logits["atype"] = self.atype_head(pooled)
        return logits


# ── Lazy singleton loader ────────────────────────────────────────────────────

# Process-wide cache filled by _load(); both stay None until a load succeeds.
_encoder: Optional[_MultiTaskEncoder] = None
_vocab:   Optional[dict] = None


def _read_json(path: Path):
    """Parse the JSON file at *path* as UTF-8, closing the handle afterwards."""
    # Explicit encoding: the platform default is not UTF-8 everywhere, and the
    # vocabulary may contain non-ASCII tokens.
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _load():
    """Return the cached ``(model, vocab)`` pair, building it on first call.

    Files are looked up next to this module:
      * ``vocab.json``              - required token -> id mapping.
      * ``model_config.json``       - optional hyper-parameters; missing keys
                                      (or a missing file) fall back to defaults.
      * ``best_model.safetensors``  - preferred weights file.
      * ``best_model.pt``           - fallback weights file.

    Returns:
        Tuple ``(model, vocab)`` with the model already switched to eval mode.

    Raises:
        FileNotFoundError: if ``vocab.json`` or both weights files are missing.
    """
    global _encoder, _vocab
    if _encoder is not None:
        return _encoder, _vocab

    vocab_path = _HERE / "vocab.json"
    cfg_path   = _HERE / "model_config.json"
    sf_path    = _HERE / "best_model.safetensors"
    pt_path    = _HERE / "best_model.pt"

    if not vocab_path.exists():
        raise FileNotFoundError(f"vocab.json not found at {vocab_path}")

    vocab = _read_json(vocab_path)
    cfg = _read_json(cfg_path) if cfg_path.exists() else {}

    model = _MultiTaskEncoder(
        vocab_size=cfg.get("vocab_size", len(vocab)),
        embed_dim=cfg.get("embed_dim", 64),
        hidden_dim=cfg.get("hidden_dim", 64),
        n_layers=cfg.get("n_layers", 2),
        dropout=cfg.get("dropout", 0.3),
    )

    if sf_path.exists():
        # Imported lazily so safetensors is only needed when that file exists.
        from safetensors.torch import load_file
        state = load_file(str(sf_path))
    elif pt_path.exists():
        state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights found at {sf_path} or {pt_path}")

    model.load_state_dict(state)
    model.eval()

    # Publish both globals together, and only after every step succeeded, so
    # a failed load never leaves the cache half-initialised.
    _encoder, _vocab = model, vocab
    return _encoder, _vocab


# ── Public API ───────────────────────────────────────────────────────────────

def predict(text: str) -> dict:
    """
    Classify a KnowForge input into a transform type and an answer type.

    Args:
        text: Natural-language input (rules + question or question alone).

    Returns:
        A dict with four entries, in this order:
            "transform_type":       str, one of TRANSFORM_LABELS
            "transform_confidence": float, softmax probability of that label,
                                    rounded to 4 decimals
            "answer_type":          str, one of ATYPE_LABELS
            "atype_confidence":     float, softmax probability of that label,
                                    rounded to 4 decimals
    """
    model, vocab = _load()

    # Unknown tokens map to the "<UNK>" id (1 when the vocab lacks that entry).
    token_ids = [vocab.get(tok, vocab.get("<UNK>", 1)) for tok in _tokenize(text)]
    if not token_ids:
        # No tokens at all: feed a single padding id so the batch is non-empty.
        token_ids = [0]
    batch = torch.tensor([token_ids], dtype=torch.long)  # (1, L)

    with torch.no_grad():
        outputs = model(batch)

    # (logits key, label table, output label key, output confidence key)
    heads = (
        ("transform", TRANSFORM_LABELS, "transform_type", "transform_confidence"),
        ("atype", ATYPE_LABELS, "answer_type", "atype_confidence"),
    )
    result = {}
    for logits_key, labels, label_key, conf_key in heads:
        probs = F.softmax(outputs[logits_key][0], dim=-1)
        best = int(probs.argmax())
        result[label_key] = labels[best]
        result[conf_key] = round(float(probs[best]), 4)
    return result


def _main():
    """Command-line entry point: classify the text given as arguments.

    All command-line arguments are joined with single spaces into one input.
    With no arguments, prints a usage line and exits with status 1.
    """
    args = sys.argv[1:]
    if not args:
        print("Usage: python inference.py \"<input text>\"")
        sys.exit(1)

    result = predict(" ".join(args))
    print(f"Transform: {result['transform_type']} ({result['transform_confidence']:.2%})")
    print(f"Answer type: {result['answer_type']} ({result['atype_confidence']:.2%})")


if __name__ == "__main__":
    _main()