""" +=============================================================+ | TRIADS V6 — Graph Attention TRM + Gate-Based Halting | | | | Single model: Gate-halt (4-16 adaptive cycles) | | d=56, 4 heads, gated residuals, deep supervision | | SWA last 50 ep | 200 epochs | | | | Loads: phonons_v6_dataset.pt | +=============================================================+ DEPENDENCIES (dataset already pre-computed, no matminer needed): pip install torch numpy scikit-learn tqdm (all pre-installed on Kaggle) USAGE: python phonons_v6.py """ import os, copy, json, time, math, warnings, threading from collections import defaultdict warnings.filterwarnings('ignore') import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.optim.swa_utils import AveragedModel, SWALR from sklearn.preprocessing import StandardScaler # Notebook dashboard (IPython is always available on Kaggle) try: from IPython.display import display, HTML, clear_output IN_NOTEBOOK = True except ImportError: IN_NOTEBOOK = False # ═══════════════════════════════════════════════════════════════ # CONFIG # ═══════════════════════════════════════════════════════════════ D = 56 N_HEADS = 4 N_WARMUP = 1 # 1 unshared warm-up (param budget) N_ANGLE_RBF = 8 DROPOUT = 0.1 BATCH_SIZE = 64 EPOCHS = 200 SWA_START = 150 LR = 5e-4 WD = 1e-4 SEEDS = [42] # Gate-halt model MIN_CYCLES = 4 MAX_CYCLES = 16 GATE_HALT_THR = 0.05 # halt when max gate < this GATE_SPARSITY = 0.001 # encourage gates to close BASELINES = { 'MEGNet': 28.76, 'ALIGNN': 29.34, 'MODNet': 45.39, 'CrabNet': 47.09, 'TRIADS V4': 56.33, 'TRIADS V3.1': 63.00, 'TRIADS V1': 71.82, 'Dummy': 323.76, } # ═══════════════════════════════════════════════════════════════ # SCATTER # ═══════════════════════════════════════════════════════════════ def scatter_sum(src, idx, dim_size): out = torch.zeros(dim_size, src.shape[-1], dtype=src.dtype, device=src.device) out.scatter_add_(0, idx.unsqueeze(-1).expand_as(src), src) return out # 
# ═══════════════════════════════════════════════════════════════
# COLLATION + DATALOADER
# ═══════════════════════════════════════════════════════════════
def collate(graphs, comp, glob_phys, targets, indices, device):
    """Merge the per-crystal graphs selected by `indices` into one batch.

    Atom-level tensors are concatenated; edge indices are shifted by the
    running atom offset and triplet indices by the running edge offset so
    they remain valid inside the merged graph.

    Returns:
        (comp[indices], glob_phys[indices], graph_dict, targets[indices]),
        all moved to `device`.
    """
    az, af = [], []                  # atomic numbers / per-atom features
    ei, rb, vc, ph = [], [], [], []  # edge index / RBF / vectors / edge physics
    tr, an = [], []                  # triplet index / angle RBF
    ba, na_list = [], []             # crystal id per atom / atoms per crystal
    a_off, e_off = 0, 0              # running atom / edge offsets
    for k, i in enumerate(indices):
        g = graphs[i]
        na, ne = g['n_atoms'], g['n_edges']
        az.append(g['atom_z'])
        af.append(g['atom_features'])
        ei.append(g['edge_index'] + a_off)     # re-index atoms into the batch
        rb.append(g['edge_rbf'])
        vc.append(g['edge_vec'])
        ph.append(g['edge_physics'])
        tr.append(g['triplet_index'] + e_off)  # re-index edges into the batch
        an.append(g['angle_rbf'])
        ba.append(torch.full((na,), k, dtype=torch.long))
        na_list.append(na)
        a_off += na
        e_off += ne
    return (
        comp[indices].to(device),
        glob_phys[indices].to(device),
        {
            'atom_z': torch.cat(az).to(device),
            'atom_feat': torch.cat(af).to(device),
            'ei': torch.cat(ei, 1).to(device),
            'rbf': torch.cat(rb).to(device),
            'vec': torch.cat(vc).to(device),
            'phys': torch.cat(ph).to(device),
            'triplets': torch.cat(tr, 1).to(device),
            'angle_feat': torch.cat(an).to(device),
            'batch': torch.cat(ba).to(device),
            'n_crystals': len(indices),
            'n_atoms': na_list,
        },
        targets[indices].to(device),
    )


class Loader:
    """Minimal batch iterator over pre-built graphs (optionally shuffled)."""

    def __init__(self, graphs, comp, gp, tgt, idx, bs, dev, shuf=False):
        self.g, self.c, self.gp, self.t = graphs, comp, gp, tgt
        self.idx, self.bs, self.dev, self.shuf = np.array(idx), bs, dev, shuf

    def __iter__(self):
        # Re-chunk (and optionally re-shuffle) the index set on every epoch.
        i = self.idx.copy()
        if self.shuf:
            np.random.shuffle(i)
        self._b = [i[j:j + self.bs] for j in range(0, len(i), self.bs)]
        self._p = 0
        return self

    def __next__(self):
        if self._p >= len(self._b):
            raise StopIteration
        b = self._b[self._p]
        self._p += 1
        return collate(self.g, self.c, self.gp, self.t, b, self.dev)

    def __len__(self):
        # Ceil division: number of batches per epoch.
        return (len(self.idx) + self.bs - 1) // self.bs


# ═══════════════════════════════════════════════════════════════
# GRAPH MESSAGE PASSING LAYER (Line Graph style)
# ═══════════════════════════════════════════════════════════════
class GraphMPLayer(nn.Module):
    """Bond update (line graph) + Atom update (edge-gated)."""

    def __init__(self, d, n_angle=N_ANGLE_RBF, dropout=DROPOUT):
        super().__init__()
        # Phase 1: Bond update from angular neighbors
        self.bond_msg = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.SiLU())
        self.bond_gate = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.Sigmoid())
        self.bond_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d),
                                     nn.SiLU(), nn.Dropout(dropout))
        # Phase 2: Atom update from bonds
        self.atom_msg = nn.Sequential(nn.Linear(d*3, d), nn.SiLU())
        self.atom_gate = nn.Sequential(nn.Linear(d*3, d), nn.Sigmoid())
        self.atom_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d),
                                     nn.SiLU(), nn.Dropout(dropout))

    def forward(self, atoms, bonds, ei, triplets, angle_feat):
        """One round of bond-then-atom message passing.

        Returns the residually-updated (atoms, bonds) pair.
        """
        # Phase 1: bonds learn from angular neighbors (skipped when the
        # batch contains no triplets)
        if triplets.shape[1] > 0:
            b_ij, b_kj = bonds[triplets[0]], bonds[triplets[1]]
            inp = torch.cat([b_ij, b_kj, angle_feat], -1)
            msg = self.bond_msg(inp) * self.bond_gate(inp)
            # FIX: aggregate with the shared scatter_sum helper instead of a
            # hand-rolled zeros(...)+scatter_add_ buffer that hardcoded
            # dtype=torch.float32 — the hardcoded dtype breaks scatter_add_
            # when msg is half-precision (autocast), and this now matches
            # Phase 2's aggregation style.
            agg = scatter_sum(msg, triplets[0], bonds.size(0))
            bonds = bonds + self.bond_up(torch.cat([bonds, agg], -1))
        # Phase 2: atoms aggregate gated messages from their incident bonds
        inp = torch.cat([atoms[ei[0]], atoms[ei[1]], bonds], -1)
        msg = self.atom_msg(inp) * self.atom_gate(inp)
        agg = scatter_sum(msg, ei[1], atoms.size(0))
        atoms = atoms + self.atom_up(torch.cat([atoms, agg], -1))
        return atoms, bonds
# ═══════════════════════════════════════════════════════════════
# PHONON V6 MODEL
# ═══════════════════════════════════════════════════════════════
class PhononV6(nn.Module):
    """
    Graph Attention TRM for phonon prediction.

    mode='fixed':     Fixed n_cycles TRM cycles (Model 1)
    mode='gate_halt': Gate-based implicit halting (Model 2)
    """

    def __init__(self, comp_dim, global_phys_dim=15, d=D, mode='gate_halt',
                 n_cycles=MAX_CYCLES, min_cycles=MIN_CYCLES,
                 max_cycles=MAX_CYCLES, n_warmup=N_WARMUP,
                 n_heads=N_HEADS, dropout=DROPOUT):
        super().__init__()
        self.d = d
        self.mode = mode
        # 'fixed' always runs n_cycles; 'gate_halt' runs between min_cycles
        # and max_cycles, stopping when the y-gate collapses.
        self.total_cycles = n_cycles if mode == 'fixed' else max_cycles
        self.min_cycles = min_cycles if mode == 'gate_halt' else self.total_cycles

        # Feature layout (from V6 dataset: 132 magpie + extras + 11 struct + 200 m2v)
        self.n_magpie = 132
        self.n_extra = comp_dim - 132 - 11 - 200
        self.n_comp_tokens = 22 + 1 + 1  # 22 magpie + 1 extra + 1 m2v = 24

        # ── Input Encoding ────────────────────────────────────
        self.atom_embed = nn.Embedding(103, d)
        self.atom_feat_proj = nn.Linear(18, d)
        self.rbf_enc = nn.Linear(40, d)
        self.vec_enc = nn.Linear(3, d)
        self.phys_enc = nn.Linear(8, d)

        # ── Composition Token Projections ─────────────────────
        self.magpie_proj = nn.Linear(6, d)
        # NOTE(review): max(n_extra, 1) guards the Linear width, but forward()
        # still slices n_extra columns — an n_extra of 0 would mismatch; the
        # V6 dataset presumably guarantees n_extra >= 1.
        self.extra_proj = nn.Linear(max(self.n_extra, 1), d)
        self.m2v_proj = nn.Linear(200, d)

        # ── Context (structural + global physics) ─────────────
        self.ctx_proj = nn.Linear(11 + global_phys_dim, d)

        # ── Token Type Embeddings ─────────────────────────────
        self.type_embed = nn.Embedding(2, d)  # 0 = composition, 1 = atom

        # ── Warm-up Layers (unshared) ─────────────────────────
        self.warmup = nn.ModuleList([GraphMPLayer(d, N_ANGLE_RBF, dropout)
                                     for _ in range(n_warmup)])
        self.warmup_out = nn.Sequential(nn.Linear(d, d), nn.LayerNorm(d), nn.SiLU())

        # ── Shared TRM Block ──────────────────────────────────
        # Graph MP (shared across cycles)
        self.trm_gnn = GraphMPLayer(d, N_ANGLE_RBF, dropout)
        # Self-Attention
        self.sa = nn.MultiheadAttention(d, n_heads, dropout=dropout,
                                        batch_first=True)
        self.sa_n = nn.LayerNorm(d)
        self.sa_ff = nn.Sequential(nn.Linear(d, d), nn.GELU(),
                                   nn.Dropout(dropout), nn.Linear(d, d))
        self.sa_fn = nn.LayerNorm(d)
        # Cross-Attention
        self.ca = nn.MultiheadAttention(d, n_heads, dropout=dropout,
                                        batch_first=True)
        self.ca_n = nn.LayerNorm(d)

        # ── State Update (Gated Residuals) ────────────────────
        self.z_proj = nn.Linear(d*3, d)
        self.z_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
        self.z_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())
        self.y_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
        self.y_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())

        # ── Output Head ───────────────────────────────────────
        self.head = nn.Sequential(nn.Linear(d, d//2), nn.SiLU(),
                                  nn.Linear(d//2, 1))
        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights, zero biases for every Linear in the model.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, comp, glob_phys, g, deep_supervision=False):
        """Encode the batch, run warm-up layers, then the TRM reasoning loop.

        Returns the last cycle's prediction [B], or — with
        deep_supervision=True — the list of per-cycle predictions.
        Also stores the cycle-averaged gate activation in self._gate_sparsity
        for the sparsity regularizer.
        """
        B = g['n_crystals']
        ei = g['ei']
        dev = comp.device

        # ── Input encoding ────────────────────────────────────
        atoms = (self.atom_embed(g['atom_z'].clamp(0, 102))
                 + self.atom_feat_proj(g['atom_feat']))
        # Bond features: distance (direction-gated) + physics
        bonds = (self.rbf_enc(g['rbf']) * torch.tanh(self.vec_enc(g['vec']))
                 + self.phys_enc(g['phys']))
        triplets = g['triplets']
        angle_feat = g['angle_feat']

        # ── Warm-up: unshared graph layers ────────────────────
        for layer in self.warmup:
            atoms, bonds = layer(atoms, bonds, ei, triplets, angle_feat)
        atoms = self.warmup_out(atoms)

        # ── Composition tokens (24 total) ─────────────────────
        magpie = comp[:, :132].view(B, 22, 6)
        extras = comp[:, 132:132 + self.n_extra]
        s_meta = comp[:, 132 + self.n_extra:132 + self.n_extra + 11]
        m2v = comp[:, -200:]
        mag_tok = self.magpie_proj(magpie)              # [B, 22, d]
        ext_tok = self.extra_proj(extras).unsqueeze(1)  # [B, 1, d]
        m2v_tok = self.m2v_proj(m2v).unsqueeze(1)       # [B, 1, d]
        comp_tok = torch.cat([mag_tok, ext_tok, m2v_tok], 1)  # [B, 24, d]
        comp_tok = comp_tok + self.type_embed.weight[0]

        # Context vector (structural + global physics)
        ctx = self.ctx_proj(torch.cat([s_meta, glob_phys], -1))  # [B, d]

        # ── TRM reasoning loop ────────────────────────────────
        z = torch.zeros(B, self.d, device=dev)
        y = torch.zeros(B, self.d, device=dev)
        preds = []
        n_atoms = g['n_atoms']
        self._gate_sparsity = 0.  # track gate magnitudes for regularizer

        for cyc in range(self.total_cycles):
            # Phase 1+2: graph MP (shared weights)
            atoms, bonds = self.trm_gnn(atoms, bonds, ei, triplets, angle_feat)

            # Pad the flat atom list to [B, max_atoms, d] for attention;
            # atom_mask marks padding positions (True = masked out).
            ma = max(n_atoms)
            atom_tok = atoms.new_zeros(B, ma, self.d)
            atom_mask = torch.ones(B, ma, dtype=torch.bool, device=dev)
            off = 0
            for i, na in enumerate(n_atoms):
                atom_tok[i, :na] = atoms[off:off + na]
                atom_mask[i, :na] = False
                off += na
            atom_tok = atom_tok + self.type_embed.weight[1]

            # Phase 3: joint self-attention over comp + atom tokens
            all_tok = torch.cat([comp_tok, atom_tok], 1)
            full_mask = torch.cat([
                torch.zeros(B, self.n_comp_tokens, dtype=torch.bool, device=dev),
                atom_mask
            ], 1)
            sa_out = self.sa(all_tok, all_tok, all_tok,
                             key_padding_mask=full_mask)[0]
            all_tok = self.sa_n(all_tok + sa_out)
            all_tok = self.sa_fn(all_tok + self.sa_ff(all_tok))
            comp_tok = all_tok[:, :self.n_comp_tokens]
            atom_tok = all_tok[:, self.n_comp_tokens:]

            # Phase 4: cross-attention (comp queries atoms)
            ca_out = self.ca(comp_tok, atom_tok, atom_tok,
                             key_padding_mask=atom_mask)[0]
            comp_tok = self.ca_n(comp_tok + ca_out)

            # Unpad atoms back to the flat layout
            parts = [atom_tok[i, :n_atoms[i]] for i in range(B)]
            atoms = torch.cat(parts, 0)

            # Phase 5: gated-residual state update (z = latent, y = answer)
            xp = comp_tok.mean(dim=1)  # [B, d]
            z_inp = self.z_proj(torch.cat([xp, ctx, y], -1))
            z_cand = self.z_up(torch.cat([z_inp, z], -1))
            z_g = self.z_gate(torch.cat([z_inp, z], -1))
            z = z + z_g * z_cand
            y_cand = self.y_up(torch.cat([y, z], -1))
            y_g = self.y_gate(torch.cat([y, z], -1))
            y = y + y_g * y_cand

            # Track gate sparsity (mean of all gate activations)
            self._gate_sparsity = self._gate_sparsity + (z_g.mean() + y_g.mean()) / 2
            preds.append(self.head(y).squeeze(-1))

            # Phase 6: gate-based halting. NOTE(review): max() is taken over
            # the whole batch, so every crystal in the batch halts together.
            if self.mode == 'gate_halt' and cyc >= self.min_cycles - 1:
                if y_g.max().item() < GATE_HALT_THR:
                    break

        # Normalize gate sparsity by number of cycles actually run
        n_run = len(preds)
        self._gate_sparsity = self._gate_sparsity / max(n_run, 1)
        return preds if deep_supervision else preds[-1]

    def count_parameters(self):
        """Total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ═══════════════════════════════════════════════════════════════
# LOSS FUNCTIONS
# ═══════════════════════════════════════════════════════════════
def deep_sup_loss(preds_list, targets):
    """Linearly-weighted deep supervision: later cycles get more weight."""
    p = torch.stack(preds_list)  # [n_cycles, B]
    w = torch.arange(1, p.shape[0] + 1, device=p.device, dtype=p.dtype)
    w = w / w.sum()  # weights sum to 1, proportional to cycle index
    return (w * (p - targets.unsqueeze(0)).abs().mean(1)).sum()


def gate_halt_loss(preds_list, targets, gate_sparsity):
    """Deep supervision + gate sparsity to encourage early halting."""
    return deep_sup_loss(preds_list, targets) + GATE_SPARSITY * gate_sparsity


# ═══════════════════════════════════════════════════════════════
# STRATIFIED SPLIT (within train fold → train/val)
# ═══════════════════════════════════════════════════════════════
def strat_split(t, vf=0.15, seed=42):
    """Quartile-stratified train/val split of target array `t`.

    Draws ~`vf` of each quartile bin into the validation set; returns
    (train_indices, val_indices) as numpy arrays.
    """
    bins = np.digitize(t, np.percentile(t, [25, 50, 75]))  # bins 0..3
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(bins == b)[0]
        if len(m) == 0:
            continue
        n = max(1, int(len(m) * vf))  # at least one val sample per bin
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist())
        tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)
# ═══════════════════════════════════════════════════════════════
# LIVE DASHBOARD (IPython HTML — works in Kaggle/Jupyter)
# ═══════════════════════════════════════════════════════════════
_print_lock = threading.Lock()

# Shared state updated by training threads, read by dashboard
_dash_state = {
    'GH': {'fold': 0, 'ep': 0, 'tr': float('inf'), 'val': float('inf'),
           'best': float('inf'), 'best_ep': 0, 'lr': 0., 'eta_m': 0,
           'ep_s': 0., 'swa': False, 'done': False, 'test_mae': None},
}
_dash_log = []  # Accumulates milestone messages


def _log(msg):
    """Thread-safely append a milestone message; echo to stdout outside notebooks."""
    with _print_lock:
        _dash_log.append(msg)
        if not IN_NOTEBOOK:
            print(msg, flush=True)


def _render_html():
    """Build an HTML table from _dash_state + recent log lines."""
    # NOTE(review): the CSS literal and the row/header template were truncated
    # in this copy of the file (an empty '' string and bare column-name
    # remnants: Model | Fold | Epoch | Train MAE | Val MAE | Best MAE | LR |
    # s/ep | ETA | Status). The markup below is a reconstruction from those
    # remnants — confirm against the original notebook rendering.
    css = (
        '<style>'
        '.dash td,.dash th{padding:3px 10px;border:1px solid #888;'
        'font-family:monospace;font-size:12px;text-align:center;}'
        '</style>'
    )
    rows = ''
    for name, s in _dash_state.items():
        # Status column: finished → test MAE, SWA phase, waiting, or training.
        if s['done'] and s['test_mae']:
            status = f'✅ {s["test_mae"]:.1f}'
        elif s['swa']:
            status = 'SWA'
        elif s['ep'] == 0:
            status = 'Waiting'
        else:
            status = '▶ Training'
        # '-' placeholders for metrics not yet populated.
        ep_str = f"{s['ep']}/{EPOCHS}" if s['ep'] else '-'
        tr_str = f"{s['tr']:.1f}" if s['tr'] < 1e6 else '-'
        val_str = f"{s['val']:.1f}" if s['val'] < 1e6 else '-'
        best_str = f'{s["best"]:.1f}@{s["best_ep"]}' if s['best'] < 1e6 else '-'
        lr_str = f"{s['lr']:.0e}" if s['lr'] > 0 else '-'
        eps_str = f"{s['ep_s']:.1f}" if s['ep_s'] > 0 else '-'
        eta_str = f"{s['eta_m']:.0f}m" if s['eta_m'] > 0 else '-'
        fold_str = str(s['fold']) if s['fold'] else '-'
        rows += (
            f'<tr><td>{name}</td><td>{fold_str}</td><td>{ep_str}</td>'
            f'<td>{tr_str}</td><td>{val_str}</td><td>{best_str}</td>'
            f'<td>{lr_str}</td><td>{eps_str}</td><td>{eta_str}</td>'
            f'<td>{status}</td></tr>'
        )
    header = (
        '<tr><th>Model</th><th>Fold</th><th>Epoch</th>'
        '<th>Train MAE</th><th>Val MAE</th><th>Best MAE</th>'
        '<th>LR</th><th>s/ep</th><th>ETA</th><th>Status</th></tr>'
    )
    with _print_lock:
        log_html = '<br>'.join(_dash_log[-12:])  # most recent milestones
    return (f'{css}<table class="dash">{header}{rows}</table>'
            f'<div style="font-family:monospace;font-size:11px">{log_html}</div>')