---
license: apache-2.0
tags:
- chemistry
- molecular-generation
- smiles
- stamp
- drug-discovery
---
# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)
Pretrained AO-GPT (any-order GPT) over **STAMP** molecular token sequences
with a **hybrid motif + character vocabulary**. Trained on 30M unique
filtered molecules. Achieves **79.00% GenMol quality** — matching the
112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of
the vocabulary size.
## Highlights
- **Small vocab (2481)**: 2387 high-frequency atomic motifs (freq ≥ 5000)
+ 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91%
of motif occurrences as atomic tokens; rare motifs expand to chars.
- **Training-time char fallback** with log-interpolated probability
(~2% at the most frequent motif, ~15% at the cutoff). The model sees
each atomic motif in both atomic and char form, closing the
train/inference gap for OOV motifs.
- **STAMP structural tokens** (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as
natural motif boundaries — no extra `[MS]`/`[ME]` markers needed.
- **Drug-like outputs**: 100% validity, 100% uniqueness (at N=1024),
79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).
## Files
| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |
## Evaluation (N=1024 at T=0.95, top_p=0.85)
| metric | value |
|---|---:|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| **GenMol score** | **79.00%** |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |
**Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%.**
The hybrid model matches within noise at 28% of the parameter count.
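The quality and diversity numbers can be recomputed with RDKit. Below is a minimal sketch, assuming the generated SMILES strings are collected in a Python list and that RDKit's contrib `sascorer` module (shipped under `Contrib/SA_Score/`) is importable; the helper names are illustrative, not part of this repo.
```python
import os, sys
from rdkit import Chem, DataStructs, RDConfig
from rdkit.Chem import AllChem, QED

# sascorer is an RDKit contrib module, not part of the core API.
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

def genmol_quality(smiles_list):
    """Fraction of all samples that are valid AND pass QED >= 0.6 and SA <= 4.0."""
    n_pass = 0
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        if QED.qed(mol) >= 0.6 and sascorer.calculateScore(mol) <= 4.0:
            n_pass += 1
    return n_pass / len(smiles_list)

def diversity(smiles_list):
    """1 - mean pairwise Tanimoto over 1024-bit Morgan fingerprints (radius 2)."""
    mols = [Chem.MolFromSmiles(s) for s in smiles_list]
    fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024)
           for m in mols if m is not None]
    sims = [DataStructs.TanimotoSimilarity(fps[i], fps[j])
            for i in range(len(fps)) for j in range(i + 1, len(fps))]
    return 1.0 - sum(sims) / len(sims)
```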
## Usage
### 1. Load vocab
```python
from hybrid_vocab import HybridVocab
vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos -> list of 2481 token strings
# vocab.atomic_motifs -> {smiles: id} for the 2387 motifs
# vocab.motif_freq -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```
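A quick look at what the vocab stores for a single motif, using only the attributes above (the benzene SMILES is just an illustrative key; substitute any key of `vocab.atomic_motifs`):
```python
motif = "c1ccccc1"                            # illustrative; pick any key of vocab.atomic_motifs
atomic_id = vocab.atomic_motifs.get(motif)    # single-token id if the motif is atomic, else None
char_ids = vocab.motif_expansion.get(motif)   # the same motif as a run of SMILES char token ids
freq = vocab.motif_freq.get(motif, 0)         # training-corpus frequency (atomic cutoff: 5000)
print(motif, atomic_id, char_ids, freq)
```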
### 2. Load model
```python
import torch
# Requires the STAMP repo (https://github.com/...) on the path to import `stamp.benchmark.lm`.
from stamp.benchmark.lm import LMConfig, TinyDecoderLM
ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True # AO-GPT arch
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)
state = ckpt["model_state"]
# Strip torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```
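A quick sanity check that the checkpoint loaded as expected (the parameter count should come out around the 31M in the title):
```python
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M params | epoch {ckpt['epoch']} | "
      f"{ckpt['model_type']} | {ckpt['representation']}")
```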
### 3. Sample (AR, top-p)
```python
import torch
from hybrid_vocab import is_stamp_structural
BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = {PAD, BOS, MASK, UNK}
T, P = 0.95, 0.85   # temperature, nucleus threshold
max_new = 64

@torch.no_grad()
def sample(n_samples=64):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for step in range(max_new):
        inp = x[:, -cfg.max_seq_len:]
        # Identity order = plain left-to-right AR decoding of the any-order model.
        orders = torch.arange(inp.size(1), device="cuda").unsqueeze(0).expand(inp.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(inp, orders=orders)[:, -1, :].float()
        # Never sample special tokens.
        for sid in suppress:
            logits[:, sid] = float("-inf")
        # Top-p (nucleus) filtering.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token past the threshold
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)  # pad finished rows with EOS
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```
### 4. Decode token stream → SMILES
```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries."""
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i in special:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf:  # close the pending run as one motif SMILES
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            buf.append(tok)  # accumulate tokens of the current motif
    if buf:
        out.append("".join(buf))
    return out
# Then run through the STAMP codec in the stamp repo:
# from stamp.benchmark.representations import build_representation
# rep = build_representation("stamp")
# text = rep.detokenize(stamp_tokens)
# mol = rep.codec.decode_stamp_to_mol(text)
```
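Putting steps 3 and 4 together, a sketch of an end-to-end draw (assumes the STAMP repo is importable; `build_representation`, `detokenize`, and `decode_stamp_to_mol` are the calls referenced in the comments above):
```python
from rdkit import Chem
from stamp.benchmark.representations import build_representation

rep = build_representation("stamp")
ids = sample(n_samples=8)
for row in ids.tolist():
    stamp_tokens = decode_to_stamp_tokens(row)
    text = rep.detokenize(stamp_tokens)
    mol = rep.codec.decode_stamp_to_mol(text)
    if mol is not None:
        print(Chem.MolToSmiles(mol))
```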
## Sample outputs
Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):
```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1 QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1 QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1 QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1 QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2 QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1 QED=0.884 SA=2.96 MW=345
```
## Architecture notes
- **AO-GPT**: decoder-only transformer with causal attention over a
shuffled token order per batch (random permutation of middle tokens,
BOS/EOS pinned at ends). Target position is conditioned via AdaLN so
the model learns "any-order" decoding.
- **Hybrid vocab**: structural tokens + SMILES char tokens + atomic motif
tokens share a single id space. At training time, atomic motif tokens
may be expanded to their SMILES char form with a
log-frequency-weighted probability (`HybridVocab.fallback_prob`) so
  the model is not brittle at char-level decoding (see the sketch after this list).
- **Decoder**: the STAMP structural tokens delimit motifs; consecutive
character tokens between structural tokens concatenate to a single
motif SMILES, which the STAMP codec parses to a molecule via a stack
machine with safety fallbacks.
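The fallback probability mentioned above is interpolated in log-frequency between the two stated endpoints (~15% at the freq = 5000 cutoff, ~2% at the most frequent motif). A minimal sketch of one plausible form; the exact constants and interpolation live in `HybridVocab.fallback_prob`:
```python
import math

def fallback_prob(freq, f_max, f_min=5_000, p_at_cutoff=0.15, p_at_top=0.02):
    """Probability of expanding an atomic motif to its char form at training time.

    Linear in log(freq) between the vocab cutoff (f_min) and the most frequent
    motif (f_max); rarer motifs fall back to chars more often.
    """
    t = (math.log(freq) - math.log(f_min)) / (math.log(f_max) - math.log(f_min))
    t = min(max(t, 0.0), 1.0)  # clamp to [0, 1]
    return p_at_cutoff + t * (p_at_top - p_at_cutoff)

# e.g. with f_max taken from the shipped frequency table:
# f_max = max(vocab.motif_freq.values())
```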
## License
Apache-2.0.
## Citation
Cite the STAMP representation paper and this repository. (Placeholder —
fill in with your actual citation info.)