# tinystories-25m / model.py
# epoyraz's picture  (web-page residue: uploader avatar caption)
# Upgrade to modded+Muon+zero-init checkpoint (val 2.65 -> 2.40)
# Commit 4cae3f7 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint
def soft_cap(logits, cap):
    """Smoothly bound logits to [-cap, cap] using ``cap * tanh(logits / cap)``.

    This is the Gemma2 / modded-nanoGPT logit soft-cap. Any falsy ``cap``
    (``0``, ``None``) turns the cap off and ``logits`` is returned as-is.
    """
    if not cap:
        return logits
    return torch.tanh(logits / cap) * cap
def chunked_cross_entropy(hidden, weight, targets, cap=0, chunk_size=2048):
    """Cross-entropy over ``soft_cap(hidden @ weight.T)`` computed chunk by chunk.

    ``hidden`` is flattened to ``[N, dim]`` and ``targets`` to ``[N]``; the loss is
    then accumulated over slices of ``chunk_size`` rows so that only one chunk of
    logits exists at a time. When gradients are needed each chunk runs under
    activation checkpointing, so its logits are recomputed in backward instead of
    being kept alive. Targets equal to ``-1`` are ignored, and the summed loss is
    divided by the number of non-ignored targets (at least 1).
    """
    flat_hidden = hidden.reshape(-1, hidden.size(-1))
    flat_targets = targets.reshape(-1)
    # Denominator of the mean; the clamp avoids 0/0 when every target is ignored.
    valid_count = (flat_targets != -1).sum().clamp(min=1)

    def summed_chunk_loss(h_chunk, t_chunk, w):
        capped = soft_cap(F.linear(h_chunk, w), cap)
        return F.cross_entropy(capped, t_chunk, ignore_index=-1, reduction="sum")

    # Checkpointing only pays off (and is only valid) when a backward pass can happen.
    needs_grad = flat_hidden.requires_grad or weight.requires_grad
    recompute_in_backward = torch.is_grad_enabled() and needs_grad

    loss_sum = flat_hidden.new_zeros(())
    for start in range(0, flat_hidden.size(0), chunk_size):
        stop = start + chunk_size
        h_chunk = flat_hidden[start:stop]
        t_chunk = flat_targets[start:stop]
        if recompute_in_backward:
            piece = checkpoint(summed_chunk_loss, h_chunk, t_chunk, weight, use_reentrant=False)
        else:
            piece = summed_chunk_loss(h_chunk, t_chunk, weight)
        loss_sum = loss_sum + piece
    return loss_sum / valid_count
# --- mHC: Manifold-Constrained Hyper-Connections ---
def sinkhorn(log_alpha, n_iters=5):
    """Run ``n_iters`` rounds of log-space Sinkhorn normalization and exponentiate.

    Each round subtracts the log-sum-exp along the last axis and then along the
    second-to-last axis, so after the final round the sums along dim -2 are 1 and
    the sums along dim -1 are approximately 1.
    """
    log_p = log_alpha
    for _ in range(n_iters):
        for axis in (-1, -2):
            log_p = log_p - torch.logsumexp(log_p, dim=axis, keepdim=True)
    return torch.exp(log_p)
class MHCResidual(nn.Module):
    """Residual update over several parallel streams with a learned mixing matrix.

    The ``n_streams x n_streams`` logits ``log_alpha`` are passed through
    ``sinkhorn`` to obtain the mixing weights; the sub-layer output is then added
    to stream 0 only.
    """

    def __init__(self, n_streams):
        super().__init__()
        self.n_streams = n_streams
        # All-zero logits: every entry of the normalized mixing matrix starts equal.
        self.log_alpha = nn.Parameter(torch.zeros(n_streams, n_streams))

    def forward(self, streams, update):
        """Mix ``streams`` (indexed ``b, stream, t, e``) and add ``update`` to stream 0."""
        mixing = sinkhorn(self.log_alpha)
        out = torch.einsum("ij,bjte->bite", mixing, streams)
        # Only stream 0 receives the new sub-layer output.
        out[:, 0] = out[:, 0] + update
        return out
class MHCExpand(nn.Module):
    """Expand a ``[B, T, C]`` activation into ``n_streams`` parallel streams.

    Returns a ``[B, n_streams, T, C]`` tensor. With a single stream this is just
    ``x.unsqueeze(1)``; otherwise one linear layer produces ``n_streams * C``
    features per token which are split into per-stream slices.
    """

    def __init__(self, n_streams, n_embd):
        super().__init__()
        self.n_streams = n_streams
        self.proj = nn.Linear(n_embd, n_streams * n_embd) if n_streams > 1 else None

    def forward(self, x):
        if self.n_streams == 1:
            return x.unsqueeze(1)
        B, T, C = x.shape
        # proj(x) is [B, T, S*C]. Split the feature axis per token FIRST, then move
        # the stream axis ahead of time. Viewing it directly as [B, S, T, C] would
        # reinterpret memory so that stream s / position t holds features of a
        # different token (including later ones), which breaks causality and
        # disagrees with single-token (T == 1) cached decoding.
        # NOTE(review): checkpoints trained with the old layout saw the scrambled
        # streams; re-validate any such mHC checkpoint after this fix.
        return self.proj(x).view(B, T, self.n_streams, C).permute(0, 2, 1, 3)
class MHCCollapse(nn.Module):
    """Merge ``[B, n_streams, T, C]`` streams back into a single ``[B, T, C]`` tensor.

    With one stream the stream axis is squeezed away; otherwise the per-token
    features of all streams are concatenated and projected down to ``n_embd``.
    """

    def __init__(self, n_streams, n_embd):
        super().__init__()
        self.n_streams = n_streams
        self.proj = nn.Linear(n_streams * n_embd, n_embd) if n_streams > 1 else None

    def forward(self, streams):
        if self.n_streams == 1:
            return streams.squeeze(1)
        B, S, T, C = streams.shape
        # Put time before streams so each token's S feature vectors sit side by side.
        per_token = streams.transpose(1, 2).reshape(B, T, S * C)
        return self.proj(per_token)
# --- BitNet: Ternary weight linear layer ---
class BitLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
self.rms_norm = nn.RMSNorm(in_features)
nn.init.normal_(self.weight, std=0.02)
def ternary_quantize(self, w):
alpha = w.abs().mean()
threshold = alpha * 0.5
w_ternary = torch.zeros_like(w)
w_ternary[w > threshold] = alpha
w_ternary[w < -threshold] = -alpha
return w_ternary.detach() + (w - w.detach())
def activation_quantize(self, x):
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
x_scaled = x * scale
x_q = x_scaled.round().clamp(-128, 127).detach() + (x_scaled - x_scaled.detach())
return x_q / scale
def forward(self, x):
x = self.rms_norm(x)
w_q = self.ternary_quantize(self.weight)
x_q = self.activation_quantize(x)
out = F.linear(x_q, w_q, self.bias)
return out
class FastBitLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
self.rms_norm = nn.RMSNorm(in_features)
nn.init.normal_(self.weight, std=0.02)
def _int8_forward(self, x):
w = self.weight.detach()
alpha = w.abs().mean()
threshold = alpha * 0.5
# Pack the ternary weight into a single signed int8 tensor {-1,0,+1} so the
# whole layer is ONE int8 matmul (dp4a / int8 tensor cores), not two. This is
# exactly equivalent to (x @ w_pos.T) - (x @ w_neg.T) but ~2x cheaper, and it
# beats fp16 at prefill/training scale.
w_ternary = torch.zeros_like(w, dtype=torch.int8)
w_ternary[w > threshold] = 1
w_ternary[w < -threshold] = -1
x_max = x.detach().abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
x_scale = 127.0 / x_max
x_q = (x.detach() * x_scale).round().clamp(-128, 127).to(torch.int8)
shape = x_q.shape
x_2d = x_q.reshape(-1, shape[-1])
rows = x_2d.shape[0]
if rows <= 16: # torch._int_mm requires more than 16 rows
x_2d = torch.nn.functional.pad(x_2d, (0, 0, 0, 17 - rows))
y = torch._int_mm(x_2d, w_ternary.T)[:rows]
else:
y = torch._int_mm(x_2d, w_ternary.T)
y = y.float().reshape(*shape[:-1], self.out_features)
return y * (alpha / x_scale)
def _ste_forward(self, x):
alpha = self.weight.abs().mean()
threshold = alpha * 0.5
w_ternary = torch.zeros_like(self.weight)
w_ternary[self.weight > threshold] = alpha
w_ternary[self.weight < -threshold] = -alpha
w_q = self.weight + (w_ternary - self.weight).detach()
x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
x_scaled = x * x_scale
x_q = x_scaled + (x_scaled.round().clamp(-128, 127) - x_scaled).detach()
x_q = x_q / x_scale
return F.linear(x_q, w_q, None)
def forward(self, x):
x = self.rms_norm(x)
if self.training:
out = self._ste_forward(x)
else:
out = self._int8_forward(x)
if self.bias is not None:
out = out + self.bias
return out
def make_linear(in_f, out_f, bias=True, use_bitnet=False, use_fast_bitnet=False):
    """Build the linear layer flavour selected by the flags.

    ``use_fast_bitnet`` takes precedence over ``use_bitnet``; with neither set a
    plain ``nn.Linear`` is returned.
    """
    if use_fast_bitnet:
        layer_cls = FastBitLinear
    elif use_bitnet:
        layer_cls = BitLinear
    else:
        layer_cls = nn.Linear
    return layer_cls(in_f, out_f, bias=bias)
# --- TurboQuant: KV-cache compression for inference ---
class PolarQuantizer:
    """Quantize vectors as (norm, unit direction), each on a uniform ``2**bits`` grid.

    The grid range for the norms and for the direction components is taken from
    the min/max over the whole tensor passed to ``quantize``. Codes are returned
    as rounded values in the input's floating dtype (no bit-packing is done here).
    """

    def __init__(self, bits=4):
        self.bits = bits
        self.levels = 2 ** bits

    def _to_codes(self, values):
        """Map ``values`` onto ``0 .. levels-1``; returns ``(codes, minimum, step)``."""
        top = self.levels - 1
        lowest = values.min()
        step = (values.max() - lowest) / top
        # The clamp on step guards the division when all values are identical.
        codes = ((values - lowest) / step.clamp(min=1e-8)).round().clamp(0, top)
        return codes, lowest, step

    def quantize(self, tensor):
        """Return ``(q_norms, q_unit, (norm_min, norm_scale, val_min, val_scale))``."""
        norms = tensor.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        q_norms, norm_min, norm_scale = self._to_codes(norms)
        q_unit, val_min, val_scale = self._to_codes(tensor / norms)
        return q_norms, q_unit, (norm_min, norm_scale, val_min, val_scale)

    def dequantize(self, q_norms, q_unit, params):
        """Rebuild an approximation of the tensor that produced the given codes."""
        norm_min, norm_scale, val_min, val_scale = params
        direction = q_unit * val_scale + val_min
        # Rounding moves the direction off the unit sphere; renormalize it.
        direction = direction / direction.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        return direction * (q_norms * norm_scale + norm_min)
class TurboQuantKVCache:
    """Append-only KV cache that stores every update in PolarQuantizer form.

    Each ``update`` call quantizes the new keys/values and appends them as one
    entry; ``get`` dequantizes all entries and concatenates them along dim 2.
    """

    def __init__(self, bits=4):
        self.quantizer = PolarQuantizer(bits=bits)
        self.k_cache = []
        self.v_cache = []

    def update(self, k_new, v_new):
        """Quantize ``k_new`` / ``v_new`` and append them to the cache."""
        qk_norms, qk_unit, k_params = self.quantizer.quantize(k_new)
        qv_norms, qv_unit, v_params = self.quantizer.quantize(v_new)
        self.k_cache.append((qk_norms, qk_unit, k_params))
        self.v_cache.append((qv_norms, qv_unit, v_params))

    def get(self):
        """Return dequantized ``(keys, values)``, or ``(None, None)`` when empty.

        The empty case mirrors ``KVCache.get()`` for an unallocated cache instead
        of failing inside ``torch.cat`` on an empty list.
        """
        if not self.k_cache:
            return None, None
        ks = [self.quantizer.dequantize(*entry) for entry in self.k_cache]
        vs = [self.quantizer.dequantize(*entry) for entry in self.v_cache]
        return torch.cat(ks, dim=2), torch.cat(vs, dim=2)

    def clear(self):
        """Drop every cached entry."""
        self.k_cache.clear()
        self.v_cache.clear()
class KVCache:
    """Preallocated key/value cache with a write cursor.

    Buffers of shape ``[B, H, max_seq_len, D]`` are allocated lazily from the
    first update and reallocated whenever the batch size, head count, head dim,
    device or dtype of the incoming keys changes. ``pos`` counts valid positions.
    """

    def __init__(self, max_seq_len):
        self.max_seq_len = max_seq_len
        self.k_cache = None
        self.v_cache = None
        self.pos = 0

    def _ensure_allocated(self, k_new, v_new):
        """(Re)allocate the buffers unless the current ones already fit ``k_new``."""
        B, H, _, D = k_new.shape
        buf = self.k_cache
        reusable = (
            buf is not None
            and buf.shape[0] == B
            and buf.shape[1] == H
            and buf.shape[3] == D
            and buf.device == k_new.device
            and buf.dtype == k_new.dtype
        )
        if reusable:
            return
        shape = (B, H, self.max_seq_len, D)
        self.k_cache = torch.empty(shape, device=k_new.device, dtype=k_new.dtype)
        self.v_cache = torch.empty(shape, device=v_new.device, dtype=v_new.dtype)
        # Fresh buffers hold no valid entries.
        self.pos = 0

    def update(self, k_new, v_new):
        """Write ``k_new`` / ``v_new`` at the cursor and advance it.

        Raises:
            ValueError: if the write would run past ``max_seq_len``.
        """
        self._ensure_allocated(k_new, v_new)
        T = k_new.size(2)
        if self.pos + T > self.max_seq_len:
            raise ValueError(f"KV cache length {self.pos + T} exceeds max_seq_len {self.max_seq_len}")
        window = slice(self.pos, self.pos + T)
        self.k_cache[:, :, window, :].copy_(k_new)
        self.v_cache[:, :, window, :].copy_(v_new)
        self.pos += T

    def get(self):
        """Return views of the valid prefix, or ``(None, None)`` before allocation."""
        if self.k_cache is None:
            return None, None
        valid = slice(None, self.pos)
        return self.k_cache[:, :, valid, :], self.v_cache[:, :, valid, :]

    def clear(self):
        """Reset the cursor; the buffers are kept for reuse."""
        self.pos = 0
# --- MTP: Multi-Token Prediction ---
class MTPHead(nn.Module):
    """Auxiliary multi-token-prediction head.

    Applies a linear projection and LayerNorm to the trunk's hidden states and
    scores them with its own ``lm_head``. For the loss, position ``t`` of the
    hidden states is paired with ``targets[:, t + future_idx]``.
    """

    def __init__(self, config, future_idx):
        super().__init__()
        self.future_idx = future_idx
        n_embd = config["n_embd"]
        vocab_size = config["vocab_size"]
        self.logit_cap = config.get("logit_cap", 0)
        self.use_chunked_loss = config.get("use_chunked_loss", False)
        self.loss_chunk_size = config.get("loss_chunk_size", 2048)
        self.proj = nn.Linear(n_embd, n_embd)
        self.ln = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, hidden, targets=None):
        """Return ``(logits, loss)``.

        On the chunked-loss path ``logits`` is ``None``. ``loss`` is ``None`` when
        no targets are given or the sequence is not longer than ``future_idx``.
        """
        shift = self.future_idx

        if targets is not None and self.use_chunked_loss:
            if targets.size(1) <= shift:
                return None, None
            # Project only the positions with a future target, then reduce in chunks.
            h = self.ln(self.proj(hidden[:, :-shift]))
            loss = chunked_cross_entropy(
                h, self.lm_head.weight, targets[:, shift:], self.logit_cap, self.loss_chunk_size
            )
            return None, loss

        h = self.ln(self.proj(hidden))
        logits = soft_cap(self.lm_head(h), self.logit_cap)
        if targets is None or targets.size(1) <= shift:
            return logits, None

        kept_logits = logits[:, :-shift].contiguous()
        kept_targets = targets[:, shift:].contiguous()
        loss = F.cross_entropy(
            kept_logits.view(-1, kept_logits.size(-1)),
            kept_targets.view(-1),
            ignore_index=-1,
        )
        return logits, loss
# --- RoPE: Rotary Position Embeddings ---
class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    ``forward(seq_len)`` returns the first ``seq_len`` rows of the ``cos`` and
    ``sin`` tables, each of shape ``[seq_len, dim]``. The tables are built for
    ``max_seq_len`` positions up front and grown on demand if a longer length is
    requested, instead of silently returning a truncated table.
    """

    def __init__(self, dim, max_seq_len=4096, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        """(Re)build the non-persistent cos/sin buffers for ``seq_len`` positions."""
        # Build on inv_freq's device so a rebuild after .to(device) does not mix
        # devices, and in float32 so the angles stay accurate even if the module
        # was cast to a lower precision; the result is cast to inv_freq's dtype.
        device = self.inv_freq.device
        dtype = self.inv_freq.dtype
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, self.inv_freq.to(torch.float32))
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, seq_len):
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x):
    """Split the last dim in two halves ``(a, b)`` and return ``(-b, a)`` concatenated."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def apply_rope(q, k, cos, sin):
    """Apply the rotary rotation given by ``cos`` / ``sin`` to both ``q`` and ``k``.

    ``cos`` and ``sin`` get two leading singleton axes so they broadcast against
    the leading dimensions of ``q`` and ``k``.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def rotate(t):
        return t * cos_b + rotate_half(t) * sin_b

    return rotate(q), rotate(k)
# --- SwiGLU MLP ---
class SwiGLU(nn.Module):
    """Gated MLP: ``down(silu(gate(x)) * up(x))``.

    The hidden width is ``int(4 * n_embd * 2 / 3)`` rounded up to a multiple of
    64. All three projections are bias-free and honour the BitNet config flags.
    """

    def __init__(self, config):
        super().__init__()
        n_embd = config["n_embd"]
        multiple = 64
        hidden = int(4 * n_embd * 2 / 3)
        hidden = (hidden + multiple - 1) // multiple * multiple
        linear_kwargs = {
            "bias": False,
            "use_bitnet": config.get("use_bitnet", False),
            "use_fast_bitnet": config.get("use_fast_bitnet", False),
        }
        self.gate = make_linear(n_embd, hidden, **linear_kwargs)
        self.up = make_linear(n_embd, hidden, **linear_kwargs)
        self.down = make_linear(hidden, n_embd, **linear_kwargs)

    def forward(self, x):
        gated = F.silu(self.gate(x)) * self.up(x)
        return self.down(gated)
class ReLU2MLP(nn.Module):
    """Ungated MLP with a squared-ReLU activation (modded-nanoGPT).

    Computes ``proj(relu(fc(x)) ** 2)`` with a ``4 * n_embd`` hidden width and
    bias-free projections. Simpler and a bit faster than SwiGLU; competitive
    quality at small scale.
    """

    def __init__(self, config):
        super().__init__()
        n_embd = config["n_embd"]
        hidden = 4 * n_embd
        linear_kwargs = {
            "bias": False,
            "use_bitnet": config.get("use_bitnet", False),
            "use_fast_bitnet": config.get("use_fast_bitnet", False),
        }
        self.fc = make_linear(n_embd, hidden, **linear_kwargs)
        self.proj = make_linear(hidden, n_embd, **linear_kwargs)

    def forward(self, x):
        activated = torch.square(F.relu(self.fc(x)))
        return self.proj(activated)
# --- Core model ---
def make_norm(n_embd, use_rmsnorm=False):
if use_rmsnorm:
return nn.RMSNorm(n_embd)
return nn.LayerNorm(n_embd)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional GQA, QK-Norm, RoPE and KV cache.

    Config keys read: ``n_head``, ``n_embd``, and optionally ``n_kv_head``
    (defaults to ``n_head``), ``use_rope``, ``use_qk_norm``, ``use_bitnet``,
    ``use_fast_bitnet`` and ``block_size`` (RoPE table length, default 512).

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head`` or ``n_head`` is
            not divisible by ``n_kv_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config["n_head"]
        self.n_embd = config["n_embd"]
        self.n_kv_head = config.get("n_kv_head", self.n_head)
        if self.n_embd % self.n_head != 0:
            raise ValueError(f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")
        if self.n_head % self.n_kv_head != 0:
            raise ValueError(f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
        self.head_dim = self.n_embd // self.n_head
        self.use_rope = config.get("use_rope", False)
        self.use_qk_norm = config.get("use_qk_norm", False)
        use_bitnet = config.get("use_bitnet", False)
        use_fast_bitnet = config.get("use_fast_bitnet", False)
        self.q_proj = make_linear(self.n_embd, self.n_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
        self.k_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
        self.v_proj = make_linear(self.n_embd, self.n_kv_head * self.head_dim, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
        self.proj = make_linear(self.n_embd, self.n_embd, use_bitnet=use_bitnet, use_fast_bitnet=use_fast_bitnet)
        # QK-Norm (modded-nanoGPT): RMSNorm Q and K over the head dim before attention.
        if self.use_qk_norm:
            self.q_norm = nn.RMSNorm(self.head_dim)
            self.k_norm = nn.RMSNorm(self.head_dim)
        if self.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.get("block_size", 512))

    def forward(self, x, kv_cache=None, pos_offset=0):
        """Attend over ``x`` (``[B, T, C]``) and, if given, previously cached keys/values.

        Args:
            x: input activations of shape ``[B, T, C]``.
            kv_cache: optional cache object exposing ``update(k, v)`` and ``get()``;
                the new keys/values are appended before attention runs.
            pos_offset: absolute position of ``x[:, 0]``, used to select RoPE rows.

        Returns:
            Tensor of shape ``[B, T, C]``.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        if self.use_rope:
            cos, sin = self.rope(pos_offset + T)
            cos, sin = cos[pos_offset:pos_offset + T], sin[pos_offset:pos_offset + T]
            q, k = apply_rope(q, k, cos, sin)
        if self.n_kv_head < self.n_head:
            # GQA: replicate each KV head so K/V line up with the query heads.
            repeats = self.n_head // self.n_kv_head
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)
        if kv_cache is not None:
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        kv_len = k.size(2)
        if T == 1:
            # A single query may see every key that is present.
            out = F.scaled_dot_product_attention(q, k, v)
        elif kv_len == T:
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # Several new queries on top of cached keys. ``is_causal=True`` aligns the
            # mask to the TOP-LEFT corner (query i sees keys 0..i), which would hide
            # the cached prefix. Query i must instead see keys 0..(kv_len - T + i),
            # i.e. a lower-triangular mask aligned to the bottom-right corner.
            causal_mask = torch.ones(T, kv_len, dtype=torch.bool, device=q.device).tril(
                diagonal=kv_len - T
            )
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
class MLP(nn.Module):
    """Two-layer feed-forward block: ``proj(gelu(fc(x)))`` with a ``4 * n_embd`` hidden width."""

    def __init__(self, config):
        super().__init__()
        n_embd = config["n_embd"]
        hidden = 4 * n_embd
        linear_kwargs = {
            "use_bitnet": config.get("use_bitnet", False),
            "use_fast_bitnet": config.get("use_fast_bitnet", False),
        }
        self.fc = make_linear(n_embd, hidden, **linear_kwargs)
        self.proj = make_linear(hidden, n_embd, **linear_kwargs)

    def forward(self, x):
        activated = F.gelu(self.fc(x))
        return self.proj(activated)
class Block(nn.Module):
    """Pre-norm transformer block: attention sub-layer followed by an MLP sub-layer.

    The MLP flavour is picked from the config (``use_relu2`` wins over
    ``use_swiglu``, otherwise the GELU ``MLP``). With ``use_mhc`` the two residual
    additions are replaced by ``MHCResidual`` updates over a stack of streams.
    """

    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.use_mhc = config.get("use_mhc", False)
        use_rmsnorm = config.get("use_rmsnorm", False)
        n_embd = config["n_embd"]

        self.ln1 = make_norm(n_embd, use_rmsnorm)
        self.attn = CausalSelfAttention(config)
        self.ln2 = make_norm(n_embd, use_rmsnorm)

        if config.get("use_relu2", False):
            mlp_cls = ReLU2MLP
        elif config.get("use_swiglu", False):
            mlp_cls = SwiGLU
        else:
            mlp_cls = MLP
        self.mlp = mlp_cls(config)

        if self.use_mhc:
            n_streams = config.get("mhc_streams", 4)
            self.mhc_attn = MHCResidual(n_streams)
            self.mhc_mlp = MHCResidual(n_streams)

    def forward(self, x, streams=None, kv_cache=None, pos_offset=0):
        """Run the block.

        Returns the updated ``streams`` when mHC is enabled and ``streams`` is
        given (``x`` is not used on that path); otherwise returns the updated ``x``.
        """
        if not self.use_mhc or streams is None:
            # Plain residual path.
            x = x + self.attn(self.ln1(x), kv_cache=kv_cache, pos_offset=pos_offset)
            x = x + self.mlp(self.ln2(x))
            return x

        # mHC path: both sub-layers read stream 0 and write through the mixers.
        attn_out = self.attn(self.ln1(streams[:, 0]), kv_cache=kv_cache, pos_offset=pos_offset)
        streams = self.mhc_attn(streams, attn_out)
        mlp_out = self.mlp(self.ln2(streams[:, 0]))
        return self.mhc_mlp(streams, mlp_out)
class GPT(nn.Module):
    """Decoder-only transformer language model configured from a plain dict.

    Required config keys: ``vocab_size``, ``block_size``, ``n_embd``, ``n_layer``
    (plus whatever ``Block`` needs). Optional features read here: ``use_mhc``,
    ``use_mtp`` / ``mtp_heads`` / ``mtp_weight`` / ``tie_mtp_lm_head``,
    ``use_rope``, ``use_turboquant`` / ``turboquant_bits``,
    ``use_activation_checkpointing``, ``logit_cap``, ``use_chunked_loss`` /
    ``loss_chunk_size``, ``use_rmsnorm`` and ``use_zero_init``.

    The token embedding and ``lm_head`` share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_mhc = config.get("use_mhc", False)
        self.use_mtp = config.get("use_mtp", False)
        self.use_rope = config.get("use_rope", False)
        self.mtp_heads_n = config.get("mtp_heads", 4)
        self.mtp_weight = config.get("mtp_weight", 0.1)
        self.use_turboquant = config.get("use_turboquant", False)
        self.turboquant_bits = config.get("turboquant_bits", 4)
        self.use_activation_checkpointing = config.get("use_activation_checkpointing", False)
        self.logit_cap = config.get("logit_cap", 0)
        self.use_chunked_loss = config.get("use_chunked_loss", False)
        self.loss_chunk_size = config.get("loss_chunk_size", 2048)
        use_rmsnorm = config.get("use_rmsnorm", False)

        self.tok_emb = nn.Embedding(config["vocab_size"], config["n_embd"])
        # Learned absolute positions are only needed when RoPE is off.
        if not self.use_rope:
            self.pos_emb = nn.Embedding(config["block_size"], config["n_embd"])
        self.blocks = nn.ModuleList([Block(config, i) for i in range(config["n_layer"])])
        self.ln_f = make_norm(config["n_embd"], use_rmsnorm)
        self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"], bias=False)
        # Weight tying: input embedding and output projection share one tensor.
        self.tok_emb.weight = self.lm_head.weight

        if self.use_mhc:
            n_streams = config.get("mhc_streams", 4)
            self.mhc_expand = MHCExpand(n_streams, config["n_embd"])
            self.mhc_collapse = MHCCollapse(n_streams, config["n_embd"])

        if self.use_mtp:
            self.mtp_heads = nn.ModuleList([
                MTPHead(config, future_idx=i + 1) for i in range(self.mtp_heads_n)
            ])
            if config.get("tie_mtp_lm_head", True):
                for head in self.mtp_heads:
                    head.lm_head.weight = self.lm_head.weight

        self.apply(self._init_weights)

        # Zero-init the output projection of each block (attention out-proj + MLP
        # down-proj), muP-style (modded-nanoGPT / nanochat). Each block starts as a
        # near-identity residual and learns to contribute, which helps convergence.
        if config.get("use_zero_init", False):
            for block in self.blocks:
                torch.nn.init.zeros_(block.attn.proj.weight)
                mlp_out = getattr(block.mlp, "down", None)
                if mlp_out is None:
                    mlp_out = block.mlp.proj  # MLP / ReLU2MLP name the out-proj "proj"
                torch.nn.init.zeros_(mlp_out.weight)

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear / BitLinear / Embedding weights; zero biases."""
        if isinstance(module, (nn.Linear, BitLinear)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _compute_hidden(self, idx):
        """Embed ``idx`` (``[B, T]``) and run all blocks; returns the final-normed hidden states.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.config["block_size"]:
            raise ValueError(f"Input length {T} exceeds block_size {self.config['block_size']}")
        x = self.tok_emb(idx)
        if not self.use_rope:
            pos = torch.arange(T, device=idx.device)
            x = x + self.pos_emb(pos)
        if self.use_mhc:
            streams = self.mhc_expand(x)
            for block in self.blocks:
                if self.training and self.use_activation_checkpointing:
                    # b=block binds the current block at lambda creation time.
                    streams = checkpoint(lambda s, b=block: b(x, streams=s), streams, use_reentrant=False)
                else:
                    streams = block(x, streams=streams)
            x = self.mhc_collapse(streams)
        else:
            for block in self.blocks:
                if self.training and self.use_activation_checkpointing:
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
        return self.ln_f(x)

    def forward(self, idx, targets=None, return_hidden=False):
        """Training / scoring forward pass.

        Returns ``(logits, loss)`` or, with ``return_hidden=True``,
        ``(logits, loss, hidden)``. ``loss`` is ``None`` without ``targets``.
        ``logits`` is ``None`` on the chunked-loss path. When MTP is enabled each
        head's loss is added with weight ``mtp_weight``.
        """
        hidden = self._compute_hidden(idx)
        loss = None
        # Chunked loss avoids materializing the full [N, vocab] logits during training.
        # It can't return logits, so fall back to the dense path when logits are needed.
        if targets is not None and self.use_chunked_loss and not return_hidden:
            logits = None
            loss = chunked_cross_entropy(
                hidden, self.lm_head.weight, targets, self.logit_cap, self.loss_chunk_size
            )
        else:
            logits = soft_cap(self.lm_head(hidden), self.logit_cap)
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        if targets is not None and self.use_mtp:
            for head in self.mtp_heads:
                _, mtp_loss = head(hidden, targets)
                if mtp_loss is not None:
                    loss = loss + self.mtp_weight * mtp_loss
        if return_hidden:
            return logits, loss, hidden
        return logits, loss

    def _forward_inference(self, x, kv_caches, pos_offset=0, return_hidden=False):
        """Run the blocks on already-embedded ``x`` with one optional cache per block."""
        if self.use_mhc:
            streams = self.mhc_expand(x)
            for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
                streams = block(x, streams=streams, kv_cache=cache, pos_offset=pos_offset)
            x = self.mhc_collapse(streams)
        else:
            for block, cache in zip(self.blocks, kv_caches or [None] * len(self.blocks)):
                x = block(x, kv_cache=cache, pos_offset=pos_offset)
        hidden = self.ln_f(x)
        logits = soft_cap(self.lm_head(hidden), self.logit_cap)
        if return_hidden:
            return logits, hidden
        return logits

    def _embed(self, tokens, pos_offset=0):
        """Token embedding, plus learned positions starting at ``pos_offset`` when RoPE is off."""
        x = self.tok_emb(tokens)
        if not self.use_rope:
            T = tokens.shape[1]
            pos = torch.arange(pos_offset, pos_offset + T, device=tokens.device)
            x = x + self.pos_emb(pos)
        return x

    def _filter_logits(self, logits, top_k=None, top_p=None, min_p=None):
        """Mask logits to ``-inf`` according to top-k, then min-p, then top-p (in that order)."""
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            values, _ = torch.topk(logits, k)
            # Everything below the k-th largest logit is removed.
            logits = logits.masked_fill(logits < values[:, [-1]], -float("inf"))
        if min_p is not None and min_p > 0:
            probs = F.softmax(logits, dim=-1)
            max_probs = probs.max(dim=-1, keepdim=True).values
            remove = probs < (min_p * max_probs)
            # Never remove the arg-max token.
            top_token = logits.argmax(dim=-1, keepdim=True)
            remove.scatter_(dim=-1, index=top_token, value=False)
            logits = logits.masked_fill(remove, -float("inf"))
        if top_p is not None and 0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = sorted_probs.cumsum(dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right by one so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = torch.zeros_like(logits, dtype=torch.bool)
            remove.scatter_(dim=-1, index=sorted_idx, src=sorted_remove)
            logits = logits.masked_fill(remove, -float("inf"))
        return logits

    def _distribution(self, logits, temperature=0.8, top_k=40, top_p=None, min_p=None):
        """Return ``(token, probs)`` for 2-D ``logits``.

        ``temperature <= 0`` means greedy: the arg-max token and a one-hot
        distribution. Otherwise a token is sampled from the filtered softmax.
        """
        if temperature <= 0:
            token = logits.argmax(dim=-1, keepdim=True)
            probs = torch.zeros_like(logits)
            probs.scatter_(1, token, 1.0)
            return token, probs
        logits = self._filter_logits(logits / temperature, top_k=top_k, top_p=top_p, min_p=min_p)
        probs = F.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)
        return token, probs

    def _make_kv_caches(self, use_turboquant, use_kv_cache=True):
        """One cache per block (quantized or plain), or ``None`` when caching is disabled."""
        if not use_kv_cache:
            return None
        if use_turboquant:
            return [TurboQuantKVCache(bits=self.turboquant_bits) for _ in self.blocks]
        return [KVCache(self.config["block_size"]) for _ in self.blocks]

    def _trim_or_seed_prompt(self, idx):
        """Keep the last ``block_size`` tokens; seed an empty prompt with one token per row."""
        block_size = self.config["block_size"]
        if idx.shape[1] == 0:
            # NOTE(review): token id 1 is hard-coded; presumably the tokenizer's EOS
            # id - confirm against the tokenizer used for training.
            eos_id = 1
            # Keep the caller's batch size (at least one row) instead of collapsing
            # every empty prompt to a single sequence.
            batch = max(idx.shape[0], 1)
            idx = torch.full((batch, 1), eos_id, dtype=idx.dtype, device=idx.device)
        return idx[:, -block_size:]

    def _prefill_generation(self, idx, use_turboquant=False, use_kv_cache=True):
        """Run the whole prompt with fresh caches.

        Returns ``(logits, last_hidden, kv_caches, seq_len)`` where ``last_hidden``
        is the hidden state of the final position (kept 3-D).
        """
        kv_caches = self._make_kv_caches(use_turboquant, use_kv_cache=use_kv_cache)
        seq_len = idx.shape[1]
        x = self._embed(idx)
        logits, hidden = self._forward_inference(x, kv_caches, pos_offset=0, return_hidden=True)
        return logits, hidden[:, -1:, :], kv_caches, seq_len

    def _advance_generation_state(self, idx, idx_next, kv_caches, seq_len, use_turboquant):
        """Feed ``idx_next`` into the model and return the new generation state.

        While the cache has room this is one incremental step. Otherwise (cache
        full, or caching disabled) the last ``block_size`` tokens of ``idx`` are
        re-prefilled from scratch.
        """
        block_size = self.config["block_size"]
        if kv_caches is not None and seq_len < block_size:
            x = self._embed(idx_next, pos_offset=seq_len)
            logits, hidden = self._forward_inference(x, kv_caches, pos_offset=seq_len, return_hidden=True)
            return logits, hidden[:, -1:, :], kv_caches, seq_len + 1
        use_kv_cache = kv_caches is not None
        if kv_caches:
            for cache in kv_caches:
                cache.clear()
        idx_cond = idx[:, -block_size:]
        logits, hidden, kv_caches, seq_len = self._prefill_generation(
            idx_cond,
            use_turboquant=use_turboquant,
            use_kv_cache=use_kv_cache,
        )
        return logits, hidden, kv_caches, seq_len

    def _generate_autoregressive(
        self,
        idx,
        max_new_tokens,
        temperature=0.8,
        top_k=40,
        top_p=None,
        min_p=None,
        use_turboquant=None,
        use_kv_cache=True,
    ):
        """Sample ``max_new_tokens`` tokens one at a time and return the extended ``idx``.

        ``use_turboquant=None`` falls back to the model's configured default.
        """
        idx = self._trim_or_seed_prompt(idx)
        use_turboquant = self.use_turboquant if use_turboquant is None else use_turboquant
        logits, last_hidden, kv_caches, seq_len = self._prefill_generation(
            idx,
            use_turboquant=use_turboquant,
            use_kv_cache=use_kv_cache,
        )
        for i in range(max_new_tokens):
            idx_next, _ = self._distribution(
                logits[:, -1, :],
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
            )
            idx = torch.cat([idx, idx_next], dim=1)
            # The last sampled token never needs a forward pass of its own.
            if i < max_new_tokens - 1:
                logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
                    idx, idx_next, kv_caches, seq_len, use_turboquant
                )
        return idx

    def _mtp_draft(self, last_hidden, n_tokens, temperature=0.8, top_k=40, top_p=None, min_p=None):
        """Draw one draft token from each of the first ``n_tokens`` MTP heads.

        Returns ``(draft_tokens, draft_probs)``, two parallel lists.
        """
        # NOTE(review): head i is built with future_idx = i + 1 and MTPHead pairs
        # hidden[:, t] with targets[:, t + future_idx]. If the training targets are
        # already shifted by one (next-token targets), head 0 predicts the token
        # AFTER the one the main logits predict, while the caller verifies draft 0
        # against the main next-token distribution. Verification keeps the output
        # correct either way, but the acceptance rate depends on this alignment -
        # confirm against the training script.
        draft_tokens = []
        draft_probs = []
        for head in self.mtp_heads[:n_tokens]:
            draft_logits, _ = head(last_hidden)
            token, probs = self._distribution(
                draft_logits[:, -1, :],
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
            )
            draft_tokens.append(token)
            draft_probs.append(probs)
        return draft_tokens, draft_probs

    def _resample_on_reject(self, target_token, p_probs, q_probs, temperature):
        """Pick the replacement token after a rejected draft.

        Greedy decoding (or a vanishing residual) returns ``target_token``;
        otherwise a token is sampled from the normalized ``max(p - q, 0)``.
        """
        if temperature <= 0:
            return target_token
        residual = (p_probs - q_probs).clamp(min=0)
        denom = residual.sum(dim=-1, keepdim=True)
        # .item() relies on the single-sequence batch enforced by the caller.
        if denom.item() <= 1e-12:
            return target_token
        return torch.multinomial(residual / denom, num_samples=1)

    def _mtp_speculative_generate(
        self,
        idx,
        max_new_tokens,
        temperature=0.8,
        top_k=40,
        top_p=None,
        min_p=None,
        speculate_tokens=None,
        use_turboquant=None,
        use_kv_cache=True,
    ):
        """Speculative decoding: MTP heads draft, the main model verifies in one pass.

        Falls back to plain autoregressive generation when the model has no MTP
        heads, the batch size is not 1, the KV cache is disabled or TurboQuant is
        requested. ``speculate_tokens`` is clamped to ``1 .. mtp_heads``.
        """
        use_turboquant = self.use_turboquant if use_turboquant is None else use_turboquant
        # Batched verification needs a single sequence, MTP draft heads, and the
        # plain (rollback-able) KV cache. TurboQuant's cache cannot be rolled back
        # token-by-token, so fall back to autoregressive there.
        if not self.use_mtp or idx.size(0) != 1 or not use_kv_cache or use_turboquant:
            return self._generate_autoregressive(
                idx,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                use_turboquant=use_turboquant,
                use_kv_cache=use_kv_cache,
            )
        idx = self._trim_or_seed_prompt(idx)
        block_size = self.config["block_size"]
        draft_width = speculate_tokens or self.mtp_heads_n
        draft_width = max(1, min(draft_width, self.mtp_heads_n))
        logits, last_hidden, kv_caches, seq_len = self._prefill_generation(
            idx, use_turboquant=False, use_kv_cache=True
        )
        # p0 = main-model logits for the next token (verifies the first draft).
        p0_logits = logits[:, -1, :]
        generated = 0
        while generated < max_new_tokens:
            remaining = max_new_tokens - generated
            n_draft = min(draft_width, remaining)
            # No room left in the cache window: take one plain step (this slides the
            # window via re-prefill inside _advance_generation_state) and continue.
            if seq_len + n_draft > block_size:
                idx_next, _ = self._distribution(p0_logits, temperature, top_k, top_p, min_p)
                idx = torch.cat([idx, idx_next], dim=1)
                generated += 1
                if generated < max_new_tokens:
                    logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
                        idx, idx_next, kv_caches, seq_len, False
                    )
                    p0_logits = logits[:, -1, :]
                continue
            # 1. Draft n tokens cheaply from the MTP heads (no main-model forward).
            draft_tokens, draft_probs = self._mtp_draft(
                last_hidden, n_draft, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p
            )
            draft_seq = torch.cat(draft_tokens, dim=1)
            # 2. Verify ALL drafts in a SINGLE main-model forward pass.
            x = self._embed(draft_seq, pos_offset=seq_len)
            v_logits, v_hidden = self._forward_inference(
                x, kv_caches, pos_offset=seq_len, return_hidden=True
            )
            # 3. Walk the drafts left-to-right; draft j is checked against the main
            #    distribution at the previous position (p0 for j=0, else v_logits[j-1]).
            accepted = 0
            reject_token = None
            for j in range(n_draft):
                target_logits = p0_logits if j == 0 else v_logits[:, j - 1, :]
                target_token, p_probs = self._distribution(
                    target_logits, temperature, top_k, top_p, min_p
                )
                if temperature <= 0:
                    accept = torch.equal(draft_tokens[j], target_token)
                else:
                    # Accept with probability min(1, p / q) for the proposed token.
                    proposed = draft_tokens[j].item()
                    p = p_probs[0, proposed]
                    q = draft_probs[j][0, proposed].clamp(min=1e-12)
                    accept = torch.rand((), device=idx.device) <= torch.minimum(torch.ones_like(p), p / q)
                if accept:
                    accepted += 1
                else:
                    reject_token = self._resample_on_reject(
                        target_token, p_probs, draft_probs[j], temperature
                    )
                    break
            if accepted == n_draft:
                # Every draft matched the main model: commit them all. The cache
                # already holds them and v_hidden/v_logits give the next draft state
                # for free (no extra forward, no separate bonus token needed).
                idx = torch.cat([idx, draft_seq], dim=1)
                generated += n_draft
                seq_len += n_draft
                last_hidden = v_hidden[:, -1:, :]
                p0_logits = v_logits[:, -1, :]
            else:
                # Commit the accepted prefix plus the corrected token, then roll the
                # cache back to drop the rejected drafts' (now stale) KV entries.
                commit = torch.cat(draft_tokens[:accepted] + [reject_token], dim=1)
                idx = torch.cat([idx, commit], dim=1)
                generated += accepted + 1
                for cache in kv_caches:
                    cache.pos = seq_len + accepted
                seq_len += accepted
                if generated < max_new_tokens:
                    # reject_token's KV/hidden are not cached yet; one short forward rebases.
                    logits, last_hidden, kv_caches, seq_len = self._advance_generation_state(
                        idx, reject_token, kv_caches, seq_len, False
                    )
                    p0_logits = logits[:, -1, :]
        return idx

    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens,
        temperature=0.8,
        top_k=40,
        top_p=None,
        min_p=None,
        speculative=False,
        speculate_tokens=None,
        use_turboquant=None,
        use_kv_cache=True,
    ):
        """Generate ``max_new_tokens`` tokens after the prompt ``idx`` (``[B, T]``).

        Returns the prompt (trimmed to the last ``block_size`` tokens) followed by
        the new tokens. ``speculative=True`` requests MTP speculative decoding,
        which itself falls back to plain sampling when its preconditions are not
        met. Runs under ``torch.no_grad()``: the result is a tensor of token ids,
        so no gradient could flow through it, and tracking one would only keep the
        whole decoding graph (including the KV-cache writes) alive.
        """
        if speculative:
            return self._mtp_speculative_generate(
                idx,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                speculate_tokens=speculate_tokens,
                use_turboquant=use_turboquant,
                use_kv_cache=use_kv_cache,
            )
        return self._generate_autoregressive(
            idx,
            max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            use_turboquant=use_turboquant,
            use_kv_cache=use_kv_cache,
        )
# --- Configs ---
# Baseline architecture; every other config is this dict plus overrides.
BASE_CONFIG = {
    "vocab_size": 16384,
    "block_size": 512,
    "n_embd": 512,
    "n_head": 8,
    "n_layer": 12,
}

# Individual techniques
MHC_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4}
BITNET_CONFIG = {**BASE_CONFIG, "use_bitnet": True}
FAST_BITNET_CONFIG = {**BASE_CONFIG, "use_fast_bitnet": True}
MTP_CONFIG = {**BASE_CONFIG, "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1}
ROPE_CONFIG = {**BASE_CONFIG, "use_rope": True}
GQA_CONFIG = {**BASE_CONFIG, "n_kv_head": 2}
SWIGLU_CONFIG = {**BASE_CONFIG, "use_swiglu": True}
RMSNORM_CONFIG = {**BASE_CONFIG, "use_rmsnorm": True}
TURBOQUANT_CONFIG = {**BASE_CONFIG, "use_turboquant": True, "turboquant_bits": 4}

# Combinations
MHC_BITNET_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4, "use_bitnet": True}
MHC_MTP_CONFIG = {**BASE_CONFIG, "use_mhc": True, "mhc_streams": 4, "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1}

# Modern LLaMA-style (RoPE + GQA + SwiGLU + RMSNorm)
MODERN_CONFIG = {**BASE_CONFIG, "use_rope": True, "n_kv_head": 2, "use_swiglu": True, "use_rmsnorm": True}

# Everything
ALL_CONFIG = {
    **BASE_CONFIG,
    "use_mhc": True, "mhc_streams": 4,
    "use_bitnet": True,
    "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1,
    "use_rope": True, "n_kv_head": 2,
    "use_swiglu": True, "use_rmsnorm": True,
    "use_turboquant": True, "turboquant_bits": 4,
}

RECOMMENDED_CONFIG = {
    **BASE_CONFIG,
    "use_rope": True, "n_kv_head": 2,
    "use_swiglu": True, "use_rmsnorm": True,
    "use_mtp": True, "mtp_heads": 4, "mtp_weight": 0.1,
}

FAST_2060_CONFIG = {
    **BASE_CONFIG,
    "block_size": 256,
    "n_embd": 384,
    "n_head": 6,
    "n_layer": 8,
    "use_rope": True,
    "n_kv_head": 2,
    "use_swiglu": True,
    "use_rmsnorm": True,
}

FAST_2060_MTP_CONFIG = {
    **FAST_2060_CONFIG,
    "use_mtp": True,
    "mtp_heads": 2,
    "mtp_weight": 0.1,
    "tie_mtp_lm_head": True,
}

FAST_2060_MTP_FBITNET_CONFIG = {
    **FAST_2060_MTP_CONFIG,
    "use_fast_bitnet": True,
}

# modded-nanoGPT-style recipe. QK-Norm helps under any optimizer; ReLU2 and
# logit_cap only pay off paired with Muon's higher LR. Train with --optimizer muon.
FAST_2060_MODDED_CONFIG = {
    **FAST_2060_MTP_CONFIG,
    "use_swiglu": False,  # superseded by ReLU2 below
    "use_relu2": True,
    "use_qk_norm": True,
    "logit_cap": 15.0,
    "use_zero_init": True,  # measured: val 2.13 -> 2.04 at equal steps, free
}

# Same modded recipe but WITHOUT MTP (built on FAST_2060_CONFIG, not the _mtp one).
# This is exactly the config that won the convergence A/B (val 2.13). No MTP means a
# cleaner pure-CE loss number and faster steps, but no speculative-decoding heads.
FAST_2060_MODDED_NOMTP_CONFIG = {
    **FAST_2060_CONFIG,
    "use_swiglu": False,
    "use_relu2": True,
    "use_qk_norm": True,
    "logit_cap": 15.0,
    "use_zero_init": True,
}

FAST_2060_MTP_TURBO_CONFIG = {
    **FAST_2060_MTP_CONFIG,
    "use_turboquant": True,
    "turboquant_bits": 4,
}

TINY_FAST_CONFIG = {
    **BASE_CONFIG,
    "block_size": 256,
    "n_embd": 256,
    "n_head": 4,
    "n_layer": 6,
    "use_rope": True,
    "n_kv_head": 2,
    "use_swiglu": True,
    "use_rmsnorm": True,
}

LOW_MEMORY_2060_CONFIG = {
    **FAST_2060_CONFIG,
    "use_activation_checkpointing": True,
}

# Name -> config registry used by get_model_config(). Every *_CONFIG defined above
# is registered here, including "fast_bitnet", which was defined but unreachable
# by name before.
CONFIGS = {
    "base": BASE_CONFIG,
    "mhc": MHC_CONFIG,
    "bitnet": BITNET_CONFIG,
    "fast_bitnet": FAST_BITNET_CONFIG,
    "mtp": MTP_CONFIG,
    "rope": ROPE_CONFIG,
    "gqa": GQA_CONFIG,
    "swiglu": SWIGLU_CONFIG,
    "rmsnorm": RMSNORM_CONFIG,
    "turboquant": TURBOQUANT_CONFIG,
    "mhc_bitnet": MHC_BITNET_CONFIG,
    "mhc_mtp": MHC_MTP_CONFIG,
    "modern": MODERN_CONFIG,
    "all": ALL_CONFIG,
    "recommended": RECOMMENDED_CONFIG,
    "fast_2060": FAST_2060_CONFIG,
    "fast_2060_mtp": FAST_2060_MTP_CONFIG,
    "fast_2060_mtp_fbitnet": FAST_2060_MTP_FBITNET_CONFIG,
    "fast_2060_modded": FAST_2060_MODDED_CONFIG,
    "fast_2060_modded_nomtp": FAST_2060_MODDED_NOMTP_CONFIG,
    "fast_2060_mtp_turbo": FAST_2060_MTP_TURBO_CONFIG,
    "tiny_fast": TINY_FAST_CONFIG,
    "low_memory_2060": LOW_MEMORY_2060_CONFIG,
}
def get_model_config(name="fast_2060", **overrides):
    """Return a fresh dict: the named preset from ``CONFIGS`` with ``overrides`` applied.

    Overrides whose value is ``None`` are skipped, so callers can forward
    optional arguments untouched.

    Raises:
        ValueError: if ``name`` is not a key of ``CONFIGS``.
    """
    if name in CONFIGS:
        config = dict(CONFIGS[name])
        config.update((key, value) for key, value in overrides.items() if value is not None)
        return config
    available = ", ".join(sorted(CONFIGS))
    raise ValueError(f"Unknown config '{name}'. Available configs: {available}")
# Default config for importers. This is the same dict object as RECOMMENDED_CONFIG,
# not a copy, so mutating one mutates the other.
MODEL_CONFIG = RECOMMENDED_CONFIG

if __name__ == "__main__":
    # Smoke test: build every registered config, count parameters and run one
    # forward pass on random tokens (the input doubles as the target).
    # Size the name column to the longest key so the table stays aligned.
    name_width = max(12, max(len(name) for name in CONFIGS))
    for name, cfg in CONFIGS.items():
        model = GPT(cfg)
        n_params = sum(p.numel() for p in model.parameters())
        x = torch.randint(0, cfg["vocab_size"], (2, 64))
        # Only the loss value is printed; no backward pass, so skip the autograd graph.
        with torch.no_grad():
            logits, loss = model(x, x)
        print(f"{name:<{name_width}} | {n_params:>12,} params ({n_params/1e6:.1f}M) | loss: {loss.item():.2f}")