---
license: apache-2.0
tags:
  - chemistry
  - molecular-generation
  - smiles
  - stamp
  - drug-discovery
---

# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)

Pretrained AO-GPT (any-order GPT) over **STAMP** molecular token sequences
with a **hybrid motif + character vocabulary**. Trained on 30M unique
filtered molecules. Achieves **79.00% GenMol quality** — matching the
112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of
the vocabulary size.

## Highlights

- **Small vocab (2481)**: 2387 high-frequency atomic motifs (freq ≥ 5000)
  + 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91%
  of motif occurrences as atomic tokens; rare motifs expand to chars.
- **Training-time char fallback** with log-interpolated probability
  (~2% at the most frequent motif, ~15% at the cutoff). The model sees
  each atomic motif in both atomic and char form, closing the
  train/inference gap for OOV motifs (see the sketch after this list).
- **STAMP structural tokens** (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as
  natural motif boundaries — no extra `[MS]`/`[ME]` markers needed.
- **Drug-like outputs**: 100% validity, 100% uniqueness (at N=1024),
  79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).
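
The fallback schedule above can be pictured as a straight line in log-frequency
space between the two quoted endpoints. A minimal sketch (the parameter names
are illustrative; the shipped implementation is `HybridVocab.fallback_prob`):

```python
import math

def fallback_prob(freq: int, max_freq: int, cutoff: int = 5000,
                  p_top: float = 0.02, p_cutoff: float = 0.15) -> float:
    """Probability of expanding an atomic motif into its SMILES chars during
    training: linear in log-frequency between the most frequent motif (p_top)
    and the inclusion cutoff (p_cutoff)."""
    t = (math.log(max_freq) - math.log(freq)) / (math.log(max_freq) - math.log(cutoff))
    return p_top + min(max(t, 0.0), 1.0) * (p_cutoff - p_top)
```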

## Files

| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |

## Evaluation (N=1024 at T=0.95, top_p=0.85)

| metric | value |
|---|---:|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| **GenMol score** | **79.00%** |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − mean pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |

**Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%.**
The hybrid model matches within noise at 28% of the parameter count.
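
The filter and the diversity metric above can be reproduced with RDKit. A
minimal sketch, assuming the standard RDKit contrib SA scorer and the usual
internal-diversity definition (1 − mean pairwise Tanimoto); this is not the
script that produced the numbers above:

```python
import os, sys
from rdkit import Chem, DataStructs
from rdkit.Chem import QED, AllChem, RDConfig

# The SA scorer ships in RDKit's contrib directory.
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

def passes_genmol_filter(smiles: str) -> bool:
    """QED >= 0.6 and SA <= 4.0 on the parsed molecule."""
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and QED.qed(mol) >= 0.6 and sascorer.calculateScore(mol) <= 4.0

def diversity(smiles_list):
    """1 - mean pairwise Tanimoto over 1024-bit Morgan (r=2) fingerprints."""
    mols = [m for m in (Chem.MolFromSmiles(s) for s in smiles_list) if m is not None]
    fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in mols]
    sims = [DataStructs.TanimotoSimilarity(fps[i], fps[j])
            for i in range(len(fps)) for j in range(i + 1, len(fps))]
    return 1.0 - sum(sims) / len(sims)
```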

## Usage

### 1. Load vocab

```python
from hybrid_vocab import HybridVocab
vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos          -> list of 2481 token strings
# vocab.atomic_motifs -> {smiles: id} for the 2387 motifs
# vocab.motif_freq    -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```
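
A quick sanity check of the hybrid layout: pick any atomic motif and compare
its single-token entry with its character expansion (uses only the attributes
listed above):

```python
motif = next(iter(vocab.atomic_motifs))        # any motif SMILES key
print(motif, vocab.atomic_motifs[motif], vocab.motif_freq[motif])
print([vocab.itos[i] for i in vocab.motif_expansion[motif]])  # char-level form
```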

### 2. Load model

```python
import torch

# Clone the STAMP repo (https://github.com/...) to get `stamp.benchmark.lm`.
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)

state = ckpt["model_state"]
# Strip torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```

### 3. Sample (AR, top-p)

```python
import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = {PAD, BOS, MASK, UNK}
T, P = 0.95, 0.85
max_new = 64

@torch.no_grad()
def sample(n_samples=64):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for _ in range(max_new):
        ctx = x[:, -cfg.max_seq_len:]  # keep the context and its orders the same length
        orders = torch.arange(ctx.size(1), device="cuda").unsqueeze(0).expand(ctx.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, orders=orders)[:, -1, :].float()
        for sid in suppress:
            logits[:, sid] = float("-inf")
        # top-p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```

### 4. Decode token stream → SMILES

```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries."""
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i in special: continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf: out.append("".join(buf)); buf = []
            out.append(tok)
        else:
            buf.append(tok)
    if buf: out.append("".join(buf))
    return out

# Then run through the STAMP codec in the stamp repo:
#     from stamp.benchmark.representations import build_representation
#     rep = build_representation("stamp")
#     text = rep.detokenize(stamp_tokens)
#     mol  = rep.codec.decode_stamp_to_mol(text)
```
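
Putting steps 3 and 4 together, a minimal end-to-end sketch (assumes the STAMP
repo is importable as in step 2 and RDKit is installed for canonicalization):

```python
from rdkit import Chem
from stamp.benchmark.representations import build_representation

rep = build_representation("stamp")
ids = sample(n_samples=1024)

smiles = []
for row in ids.tolist():
    stamp_tokens = decode_to_stamp_tokens(row)
    mol = rep.codec.decode_stamp_to_mol(rep.detokenize(stamp_tokens))
    if mol is not None:
        smiles.append(Chem.MolToSmiles(mol))

print(f"{len(smiles)}/{ids.size(0)} decoded to valid molecules")
```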

## Sample outputs

Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):

```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O                   QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O            QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1          QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1            QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1                    QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1 QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O           QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C      QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2      QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1      QED=0.884 SA=2.96 MW=345
```

## Architecture notes

- **AO-GPT**: decoder-only transformer with causal attention over a
  shuffled token order per batch (random permutation of middle tokens,
  BOS/EOS pinned at the ends; see the sketch after these notes). The
  target position is conditioned via AdaLN so the model learns
  "any-order" decoding.
- **Hybrid vocab**: structural tokens + SMILES char tokens + atomic motif
  tokens share a single id space. At training time, atomic motif tokens
  may be expanded to their SMILES char form with a
  log-frequency-weighted probability (`HybridVocab.fallback_prob`) so
  the model is not brittle at char-level decoding.
- **Decoder**: the STAMP structural tokens delimit motifs; consecutive
  character tokens between structural tokens concatenate to a single
  motif SMILES, which the STAMP codec parses to a molecule via a stack
  machine with safety fallbacks.
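
For intuition, the shuffled order described in the first note can be sketched
as follows (a hypothetical helper; the actual training pipeline lives in the
STAMP repo):

```python
import torch

def any_order(seq_len: int) -> torch.Tensor:
    """Random permutation of the middle positions, with BOS (position 0)
    pinned first and EOS (position seq_len - 1) pinned last."""
    middle = torch.randperm(seq_len - 2) + 1
    return torch.cat([torch.tensor([0]), middle, torch.tensor([seq_len - 1])])

# any_order(6) -> e.g. tensor([0, 3, 1, 4, 2, 5]); the middle order varies per draw
```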

## License

Apache-2.0.

## Citation

Cite the STAMP representation paper and this repository. (Placeholder —
fill in with your actual citation info.)