FLATest: Hybrid GLA + Attention code/reasoning LM
A 308M-parameter decoder-only language model trained from scratch on a single RTX PRO 6000 (Blackwell) GPU. It interleaves Gated Linear Attention (GLA) layers with sparse full-attention layers, a hybrid sequence mixer in the spirit of Jamba and MiniMax. The GLA layers give O(N) long-context behaviour, and the attention layers keep the exact associative recall that pure linear attention loses.
This is an educational/research model trained on a single GPU with a limited token budget; it is not a state-of-the-art code assistant. Its purpose is to demonstrate a correct, scalable architecture for long context and reasoning, along with the GrokAdamW optimizer recipe.
Architecture
| Property | Value |
|---|---|
| Params | ~308M |
| d_model | 1024 |
| Layers | 24 |
| Heads | 16 (GQA, 4 KV heads) |
| Mixer | hybrid: GLA on 18 layers, full attention on every 4th layer (6 layers); pattern gggAgggAgggAgggAgggAgggA |
| Train context | 4096 tokens |
| Vocab | 49,152 (StarCoder2 BPE) |
| Position | RoPE on attention layers; GLA uses learned decay (no RoPE) |
| Norm / MLP | RMSNorm + SwiGLU |
| Embeddings | Tied input/output |
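Several numbers in the table follow from a few hyperparameters. The sketch below is a hypothetical stand-in: the class name HybridConfigSketch and the attn_every field are illustrative, not the fields of the repo's ModelConfig or config.json. It derives three things:

- the gggA layer pattern;
- the per-head width of 1024 / 16 = 64;
- the 4-to-1 grouping of query heads onto KV heads implied by GQA with 4 KV heads.

```python
from dataclasses import dataclass


@dataclass
class HybridConfigSketch:
    """Hypothetical mirror of the architecture table; NOT the repo's ModelConfig."""
    d_model: int = 1024
    n_layers: int = 24
    n_heads: int = 16
    n_kv_heads: int = 4
    attn_every: int = 4  # hypothetical field: full attention on every 4th layer

    @property
    def head_dim(self) -> int:
        # Per-head width: 1024 / 16 = 64
        return self.d_model // self.n_heads

    @property
    def gqa_group_size(self) -> int:
        # Query heads sharing one KV head under GQA: 16 / 4 = 4
        return self.n_heads // self.n_kv_heads

    def layer_pattern(self) -> str:
        # 'A' = full attention (RoPE), 'g' = GLA (learned decay, no RoPE)
        return "".join(
            "A" if (i + 1) % self.attn_every == 0 else "g"
            for i in range(self.n_layers)
        )


cfg = HybridConfigSketch()
print(cfg.layer_pattern())               # gggAgggAgggAgggAgggAgggA
print(cfg.head_dim, cfg.gqa_group_size)  # 64 4
print(cfg.layer_pattern().count("A"))    # 6
```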
Why hybrid: pure GLA fails exact associative recall (recall ≈ chance on an induction probe), while a few interleaved attention layers restore it (recall ≈ 1.0). The GLA layers scale linearly with context length, so the real payoff is long-form generation. In our decode benchmark at 64k output tokens, GLA decoding from its recurrent state was ~8.7× faster and used ~20× less memory than attention with a KV cache.
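The sketch below estimates why decode memory diverges. An attention layer caches keys and values for every past token, while a GLA layer keeps one fixed-size d_k × d_v state per head.

The following sizes are assumptions for illustration:

- a 64 × 64 GLA state per head;
- an fp32 state;
- a bf16 KV cache.

The functions count only the cache and state tensors, not weights or activations, so their ratios are not expected to reproduce the benchmark figures above.

```python
def kv_cache_bytes(n_tokens: int, n_layers: int, n_kv_heads: int = 4,
                   head_dim: int = 64, bytes_per_elem: int = 2) -> int:
    """Keys + values cached for every past token in every attention layer."""
    return 2 * n_layers * n_kv_heads * head_dim * n_tokens * bytes_per_elem


def gla_state_bytes(n_layers: int, n_heads: int = 16, d_k: int = 64,
                    d_v: int = 64, bytes_per_elem: int = 4) -> int:
    """One d_k x d_v state matrix per head per layer; independent of n_tokens."""
    return n_layers * n_heads * d_k * d_v * bytes_per_elem


MiB = 2 ** 20

# Hybrid stack from the table: 6 attention layers + 18 GLA layers (sizes assumed)
for n_tokens in (4_096, 65_536):
    kv = kv_cache_bytes(n_tokens, n_layers=6)
    state = gla_state_bytes(n_layers=18)
    print(f"{n_tokens:>6} tokens: KV cache {kv / MiB:7.1f} MiB, "
          f"GLA state {state / MiB:5.1f} MiB")
```

The KV-cache term grows linearly with the number of generated tokens, while the GLA term stays constant. The hybrid's remaining attention layers are why its decode memory still grows, just more slowly than in an all-attention stack.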
Training
- Optimizer: GrokAdamW with decoupled weight decay (0.1), betas (0.9, 0.95), a cautious update, and an optional Grokfast EMA gradient filter. The weight-decay-driven recipe was verified to reproduce grokking on modular addition (validation accuracy 0 → 1.0).
- Data: infinite mixed stream of code documents (bigcode/starcoderdata) and reasoning traces (open-r1/OpenR1-Math-220k) at a mixing ratio of 0.3. Reasoning examples are prompt-masked: loss is computed only on <think>…</think> and the answer.
- Schedule: warmup + cosine, bf16 autocast, grad clip 1.0, effective batch 64 (262k tokens/step), torch.compile.
- Throughput: ~81k tok/s, ~58 GB peak memory on the PRO 6000.
- Reasoning format: special tokens <think> / </think>; the model learns to reason in text before answering.
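The optimizer bullet combines several published ideas. The sketch below shows one way they can compose in a single update on plain Python lists. It is a hypothetical illustration, not the code in optim.py, and the order of operations is an assumption.

- Grokfast EMA filter: the defaults alpha = 0.98 and lamb = 2.0 come from the reference Grokfast EMA filter and may differ from what GrokAdamW uses here.
- Cautious mask: follows the Cautious Optimizers recipe. It keeps only the components where momentum and gradient agree in sign, then rescales by the fraction kept.
- Weight decay: decoupled, applied AdamW-style directly to the parameters.

```python
import math


class GrokAdamWSketch:
    """Hypothetical sketch on flat lists of floats; NOT the repo's optim.py.

    One step = optional Grokfast EMA filter -> Adam moments -> cautious mask
    -> decoupled (AdamW-style) weight decay.
    """

    def __init__(self, lr=1e-3, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1,
                 use_grokfast=False, grokfast_alpha=0.98, grokfast_lamb=2.0):
        self.lr, self.eps, self.wd = lr, eps, weight_decay
        self.b1, self.b2 = betas
        self.use_grokfast = use_grokfast
        self.alpha, self.lamb = grokfast_alpha, grokfast_lamb
        self.t = 0
        self.m = self.v = self.ema = None

    def step(self, params, grads):
        n = len(params)
        if self.m is None:
            self.m, self.v, self.ema = [0.0] * n, [0.0] * n, [0.0] * n
        self.t += 1
        if self.use_grokfast:
            # Grokfast-EMA: add back an amplified low-pass (slow) gradient component
            self.ema = [self.alpha * e + (1 - self.alpha) * g
                        for e, g in zip(self.ema, grads)]
            grads = [g + self.lamb * e for g, e in zip(grads, self.ema)]
        # Adam first/second moment estimates with bias correction
        self.m = [self.b1 * m + (1 - self.b1) * g for m, g in zip(self.m, grads)]
        self.v = [self.b2 * v + (1 - self.b2) * g * g for v, g in zip(self.v, grads)]
        bc1, bc2 = 1 - self.b1 ** self.t, 1 - self.b2 ** self.t
        upd = [(m / bc1) / (math.sqrt(v / bc2) + self.eps)
               for m, v in zip(self.m, self.v)]
        # Cautious update: zero components whose momentum disagrees in sign with
        # the current gradient, then rescale by the kept fraction (clamped)
        mask = [1.0 if m * g > 0 else 0.0 for m, g in zip(self.m, grads)]
        kept = max(sum(mask) / n, 1e-3)
        # Decoupled weight decay shrinks params directly, outside the Adam update
        return [p * (1 - self.lr * self.wd) - self.lr * u * k / kept
                for p, u, k in zip(params, upd, mask)]


# Usage: minimize (x - 3)^2 without weight decay so the optimum stays at 3
opt = GrokAdamWSketch(lr=0.01, weight_decay=0.0)
x = [0.0]
for _ in range(3000):
    grad = [2 * (x[0] - 3.0)]  # derivative of (x - 3)^2
    x = opt.step(x, grad)
print(x)
```

With weight_decay=0.1 the same loop settles slightly below 3. Decoupled decay pulls parameters toward zero independently of the gradient, which is the mechanism the grokking experiments rely on.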
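The data bullet describes two mechanisms: a stochastic mix of two sources, and prompt masking of reasoning examples. The sketch below makes these assumptions:

- "ratio 0.3" means each draw is a reasoning example with probability 0.3;
- examples are already tokenized into plain ints;
- masked positions use -100, PyTorch's default ignore_index for cross_entropy.

The function names are hypothetical. Label shifting for next-token prediction is assumed to happen inside the model or loss, Hugging Face-style.

```python
import random
from itertools import cycle, islice

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def mixed_stream(code_docs, reasoning_docs, reasoning_ratio=0.3, seed=0):
    """Infinite stream; each draw is a reasoning example with prob. reasoning_ratio."""
    rng = random.Random(seed)
    code_it, reason_it = cycle(code_docs), cycle(reasoning_docs)
    while True:
        if rng.random() < reasoning_ratio:
            yield "reasoning", next(reason_it)
        else:
            yield "code", next(code_it)


def prompt_masked_labels(prompt_ids, completion_ids):
    """Inputs/labels for one reasoning example: loss only on <think>...</think> + answer."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


stream = mixed_stream(["def f(): ..."], ["Q: 2+2? <think>...</think> 4"], seed=0)
kinds = [kind for kind, _ in islice(stream, 10_000)]
print(kinds.count("reasoning") / len(kinds))  # fraction of reasoning draws

ids, labels = prompt_masked_labels([11, 12, 13], [50, 7, 51, 9])
print(labels)  # [-100, -100, -100, 50, 7, 51, 9]
```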
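The schedule bullet can be sketched in numbers. The block below contains:

- a hypothetical linear-warmup + cosine-decay function;
- a global-norm gradient clip at 1.0, written to mirror the clip coefficient that torch.nn.utils.clip_grad_norm_ computes;
- the tokens-per-step arithmetic.

This card does not give the warmup length or the peak and minimum learning rates, so they are left as parameters. The 262k figure is consistent with 64 sequences of 4096 tokens, assuming the effective batch counts full-context sequences.

```python
import math


def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay to min_lr (hypothetical params)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by one factor so their joint L2 norm is <= max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total


# Effective batch of 64 sequences x 4096-token context = tokens per optimizer step
tokens_per_step = 64 * 4096
print(tokens_per_step)  # 262144
```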
Validation perplexity dropped steadily (ppl 33 → ~4 within a few thousand steps).
See config.json and training_state.json for the exact training step of the uploaded checkpoint.
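Perplexity is the exponential of the mean per-token cross-entropy in nats, so the quoted drop can be read as a loss curve. The small sketch below assumes the card's perplexities are computed this way. Under that assumption, ppl 33 corresponds to a loss of about 3.50 nats per token and ppl ~4 to about 1.39.

```python
import math


def ppl_from_loss(mean_nll_nats: float) -> float:
    """Perplexity = exp(mean per-token negative log-likelihood in nats)."""
    return math.exp(mean_nll_nats)


def loss_from_ppl(ppl: float) -> float:
    """Inverse mapping: mean per-token cross-entropy (nats) for a given perplexity."""
    return math.log(ppl)


print(round(loss_from_ppl(33), 2))  # 3.5
print(round(loss_from_ppl(4), 2))   # 1.39
```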
Files
- ckpt_last.pt: PyTorch checkpoint dict {model, opt, step, cfg}.
- config.json: the ModelConfig used to build the model.
- model.py, optim.py: model and optimizer definitions (the codetrain package).
- generate.py: inference / sampling script.
Inference
pip install torch transformers flash-linear-attention
python generate.py --ckpt ckpt_last.pt --prompt "Write a Python function that reverses a linked list."
generate.py seeds a <think> block to elicit reasoning, then samples the answer.
Requires flash-linear-attention (Triton) for the GLA layers.
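The following is a minimal sketch of the sampling loop described for generate.py. It assumes a model callable that maps the full token prefix to next-token logits. The prompt template, function names, toy model, and default temperature are all hypothetical; the real script may differ. Because every step re-runs the model on the whole prefix, the cost of each new token grows with the prefix length.

```python
import math
import random


def build_prompt(user_prompt: str) -> str:
    # Hypothetical template that opens a <think> block to elicit reasoning
    return f"{user_prompt}\n<think>\n"


def sample_next(logits, temperature=0.8, rng=None):
    """Greedy when temperature == 0, else sample from softmax(logits / temperature)."""
    if temperature == 0:
        return max(range(len(logits)), key=logits.__getitem__)
    rng = rng or random.Random()
    scaled = [l / temperature for l in logits]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]  # unnormalized, numerically stable
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def generate_full_recompute(model, prompt_ids, max_new_tokens, eos_id=None,
                            temperature=0.8, seed=0):
    """Each step re-runs `model` on the entire prefix (no cache, no recurrent state)."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = model(ids)  # hypothetical: next-token logits for the last position
        nxt = sample_next(logits, temperature, rng)
        ids.append(nxt)
        if nxt == eos_id:
            break
    return ids


def toy_model(ids, vocab=5):
    # Toy stand-in: strongly prefers the token after the last one, modulo vocab
    return [5.0 if t == (ids[-1] + 1) % vocab else 0.0 for t in range(vocab)]


print(generate_full_recompute(toy_model, [0], max_new_tokens=4, temperature=0))
# [0, 1, 2, 3, 4]
```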
Limitations
- Single GPU, limited token budget: expect incoherent or repetitive output on hard prompts. It is a scaffold, not a product.
- GLA layers require flash-linear-attention and a Triton-capable GPU. generate.py uses full-recompute decoding, which is simple and correct for both layer types. The O(1)-per-token recurrent GLA decode that gives the long-context speedup is not yet wired into the sampler.
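Full recompute and the recurrent decode give the same answer because GLA's state follows the recurrence S_t = diag(α_t) S_{t-1} + k_tᵀ v_t, with output o_t = q_t S_t (Yang et al. 2023). The state at step t therefore already summarizes the whole prefix.

The single-head toy below, in plain Python, checks that one O(1) recurrent step matches recomputing the output from scratch. It omits the normalization, output gate, and chunked Triton kernels used by flash-linear-attention, and its function names are hypothetical.

```python
import random


def gla_recurrent_step(S, q, k, v, alpha):
    """One decode step, O(1) in sequence length: S <- diag(alpha) S + k^T v, o = q S."""
    d_k, d_v = len(k), len(v)
    S = [[alpha[i] * S[i][j] + k[i] * v[j] for j in range(d_v)] for i in range(d_k)]
    o = [sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
    return S, o


def gla_output_from_scratch(qs, ks, vs, alphas):
    """Last-position output recomputed from the whole prefix (full-recompute style)."""
    t = len(qs) - 1
    d_k, d_v = len(ks[0]), len(vs[0])
    o = [0.0] * d_v
    for s in range(t + 1):
        for i in range(d_k):
            decay = 1.0
            for r in range(s + 1, t + 1):
                decay *= alphas[r][i]  # gates applied after token s was written
            for j in range(d_v):
                o[j] += qs[t][i] * decay * ks[s][i] * vs[s][j]
    return o


def rand_vec(rng, n, lo=-1.0, hi=1.0):
    return [rng.uniform(lo, hi) for _ in range(n)]


rng = random.Random(0)
d_k, d_v, T = 3, 2, 6
qs = [rand_vec(rng, d_k) for _ in range(T)]
ks = [rand_vec(rng, d_k) for _ in range(T)]
vs = [rand_vec(rng, d_v) for _ in range(T)]
alphas = [rand_vec(rng, d_k, 0.5, 1.0) for _ in range(T)]  # per-key decay gates in [0.5, 1]

S = [[0.0] * d_v for _ in range(d_k)]  # fixed-size state; does not grow with T
max_err = 0.0
for t in range(T):
    S, o_rec = gla_recurrent_step(S, qs[t], ks[t], vs[t], alphas[t])
    o_full = gla_output_from_scratch(qs[:t + 1], ks[:t + 1], vs[:t + 1], alphas[:t + 1])
    max_err = max(max_err, max(abs(a - b) for a, b in zip(o_rec, o_full)))
print(f"max |recurrent - full recompute| = {max_err:.2e}")
```

The recurrent path touches only S, a d_k × d_v matrix, per token. The from-scratch path loops over the whole prefix, which is why wiring the recurrent decode into the sampler is what unlocks the long-context speedup.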
Citation / lineage
Builds on: Gated Linear Attention (Yang et al. 2023), Grokfast (Lee et al. 2024), grokking (Power et al. 2022), hybrid linear/attention stacks (Jamba, MiniMax).