FLATest: Hybrid GLA + Attention code/reasoning LM

A 308M-parameter decoder-only language model trained from scratch on a single RTX PRO 6000 (Blackwell). It mixes Gated Linear Attention (GLA) layers with sparse full-attention layers (a hybrid sequence mixer in the spirit of Jamba / MiniMax) to get O(N) long-context behaviour while keeping the exact associative recall that pure linear attention loses.

This is an educational / research model: trained on a single GPU for a limited token budget. It is not a SOTA code assistant. Its purpose is to demonstrate a correct, scalable architecture for long-context + reasoning, and the GrokAdamW optimizer recipe.

Architecture

Params ~308M
d_model 1024
Layers 24
Heads 16 (GQA, 4 KV-heads)
Mixer hybrid: GLA on most layers, attention every 4th layer (gggAgggAgggAgggAgggAgggA)
Train context 4096
Vocab 49152 (StarCoder2 BPE)
Position RoPE on attention layers; GLA uses learned decay (no RoPE)
Norm / MLP RMSNorm + SwiGLU
Embeddings tied input/output
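
The mixer pattern in the table can be generated from two numbers: the layer count and the attention interval. The sketch below is only an illustration of that layout; the function name `layer_pattern` and its parameters are hypothetical and are not taken from model.py.

```python
def layer_pattern(n_layers: int = 24, attn_every: int = 4) -> str:
    """Return one character per layer: 'g' for GLA, 'A' for full attention.

    Every `attn_every`-th layer (counting from 1) is an attention layer,
    which matches the pattern string listed in the table above.
    """
    return "".join(
        "A" if (i + 1) % attn_every == 0 else "g"
        for i in range(n_layers)
    )


pattern = layer_pattern()
print(pattern)             # gggAgggAgggAgggAgggAgggA
print(pattern.count("A"))  # 6 attention layers
print(pattern.count("g"))  # 18 GLA layers
```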

Why hybrid: pure GLA fails exact associative recall (recall ≈ chance on an induction probe), while a few interleaved attention layers restore it (recall ≈ 1.0). The GLA layers keep the model linear in context length, so the real payoff is long-form generation: in our decode benchmark GLA's recurrent state is ~8.7× faster and ~20× lighter than an attention KV-cache at 64k output tokens.
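
To see why the two layer types scale differently during decoding, the following back-of-the-envelope sketch compares raw cache tensor sizes per layer. It assumes a head dimension of d_model / heads = 64, bf16 storage (2 bytes per element), and a GLA state of 64 × 64 per head; the GLA state dimensions in particular are an assumption, not a value from config.json. It counts tensor bytes only and does not attempt to reproduce the measured benchmark figures above.

```python
def kv_cache_bytes_per_layer(seq_len: int, n_kv_heads: int = 4,
                             head_dim: int = 64, bytes_per_elem: int = 2) -> int:
    """Keys + values for one attention layer: grows linearly with seq_len."""
    return 2 * n_kv_heads * head_dim * seq_len * bytes_per_elem


def gla_state_bytes_per_layer(n_heads: int = 16, d_k: int = 64, d_v: int = 64,
                              bytes_per_elem: int = 2) -> int:
    """One (d_k x d_v) state matrix per head: independent of seq_len.

    d_k and d_v are assumed values for illustration only.
    """
    return n_heads * d_k * d_v * bytes_per_elem


for seq_len in (4096, 65536):
    kv = kv_cache_bytes_per_layer(seq_len)
    print(f"{seq_len:>6} tokens: KV cache {kv / 2**20:.0f} MiB per attention layer")

print(f"GLA state: {gla_state_bytes_per_layer() / 2**10:.0f} KiB per GLA layer at any length")
```

The point of the sketch is the shape of the two curves: the KV cache is proportional to the number of generated tokens, while the recurrent state stays a fixed size.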

Training

  • Optimizer: GrokAdamW, with decoupled weight decay (0.1), betas (0.9, 0.95), cautious update, optional Grokfast EMA. The weight-decay-driven recipe was verified to reproduce grokking on modular addition (val acc 0 → 1.0).
  • Data: infinite mixed stream of code documents (bigcode/starcoderdata) + reasoning traces (open-r1/OpenR1-Math-220k, ratio 0.3). Reasoning examples are prompt-masked (loss only on <think>…</think> + answer).
  • Schedule: warmup + cosine, bf16 autocast, grad clip 1.0, effective batch 64 sequences (64 × 4096 ≈ 262k tokens/step), torch.compile.
  • Throughput: ~81k tok/s, ~58 GB peak on the PRO 6000.
  • Reasoning format: special tokens <think> / </think>; the model learns to reason in text before answering.
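
Prompt masking, as described in the data bullet above, is commonly implemented by setting the label of every prompt token to an ignore value so that only the reasoning and answer tokens contribute to the loss. The sketch below shows that idea with plain lists; the helper name `mask_prompt` and the token ids are hypothetical, and the actual training code may differ. The value -100 is the default `ignore_index` of `torch.nn.CrossEntropyLoss`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_prompt(prompt_ids, completion_ids):
    """Build (input_ids, labels) for one reasoning example.

    Labels are aligned with input_ids here; the usual one-position shift
    for next-token prediction is assumed to happen inside the loss code.
    """
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


# Hypothetical token ids: 3 prompt tokens, then <think> ... </think> + answer.
prompt = [11, 12, 13]
completion = [900, 21, 22, 901, 31]
input_ids, labels = mask_prompt(prompt, completion)
print(input_ids)  # [11, 12, 13, 900, 21, 22, 901, 31]
print(labels)     # [-100, -100, -100, 900, 21, 22, 901, 31]
```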

Validation perplexity dropped steadily (ppl 33 → ~4 within a few thousand steps). See config.json and training_state.json for the exact step the uploaded checkpoint corresponds to.
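
For readers who track loss rather than perplexity, the two are related by ppl = exp(loss), assuming the reported perplexity is the exponential of the mean per-token cross-entropy in nats (the usual convention, though the card does not state it). A small conversion sketch:

```python
import math


def ppl_to_loss(ppl: float) -> float:
    """Mean per-token cross-entropy (nats) implied by a perplexity."""
    return math.log(ppl)


def loss_to_ppl(loss: float) -> float:
    """Perplexity implied by a mean per-token cross-entropy (nats)."""
    return math.exp(loss)


print(f"ppl 33 -> loss {ppl_to_loss(33):.2f}")  # about 3.50
print(f"ppl 4  -> loss {ppl_to_loss(4):.2f}")   # about 1.39
```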

Files

  • ckpt_last.pt: checkpoint {model, opt, step, cfg} (PyTorch).
  • config.json: the ModelConfig used to build the model.
  • model.py, optim.py: model + optimizer definitions (the codetrain package).
  • generate.py: inference / sampling script.

Inference

pip install torch transformers flash-linear-attention
python generate.py --ckpt ckpt_last.pt --prompt "Write a Python function that reverses a linked list."

generate.py seeds a <think> block to elicit reasoning, then samples the answer. Requires flash-linear-attention (Triton) for the GLA layers.

Limitations

  • Single-GPU, limited token budget → expect incoherent or repetitive output on hard prompts. It is a scaffold, not a product.
  • GLA layers require flash-linear-attention + a Triton-capable GPU.
  • generate.py uses full-recompute decoding (simple, correct for both layer types); the O(1) recurrent GLA decode that gives the long-context speedup is not yet wired into the sampler.
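
The difference between full-recompute decoding and the recurrent decode mentioned in the last bullet can be illustrated with a single-head toy version of the GLA recurrence from Yang et al.: S_t = diag(alpha_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t. The sketch below is not the code in generate.py or model.py; it uses plain Python lists, omits normalization and output gating, and all function names are hypothetical. It checks that a fixed-size state updated once per token gives the same outputs as recomputing each output from all earlier tokens.

```python
def gla_recurrent(q, k, v, alpha):
    """One state update per token; the state S has fixed shape (d_k, d_v)."""
    d_k, d_v = len(k[0]), len(v[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    outputs = []
    for q_t, k_t, v_t, a_t in zip(q, k, v, alpha):
        for i in range(d_k):
            for j in range(d_v):
                # Decay row i of the old state, then add the outer product k_t^T v_t.
                S[i][j] = a_t[i] * S[i][j] + k_t[i] * v_t[j]
        outputs.append([sum(q_t[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def gla_full_recompute(q, k, v, alpha):
    """Reference: rebuild every output from all earlier tokens (quadratic in length)."""
    d_k, d_v = len(k[0]), len(v[0])
    outputs = []
    for t in range(len(q)):
        o_t = [0.0] * d_v
        for s in range(t + 1):
            score = 0.0
            for i in range(d_k):
                decay = 1.0
                for r in range(s + 1, t + 1):  # cumulative gate between s and t
                    decay *= alpha[r][i]
                score += q[t][i] * decay * k[s][i]
            for j in range(d_v):
                o_t[j] += score * v[s][j]
        outputs.append(o_t)
    return outputs


# Tiny deterministic example: 3 tokens, d_k = 2, d_v = 2.
q = [[1.0, 0.5], [0.2, 1.0], [0.7, 0.3]]
k = [[0.4, 0.1], [0.9, 0.6], [0.3, 0.8]]
v = [[1.0, 2.0], [0.5, 0.0], [2.0, 1.0]]
alpha = [[0.9, 0.8], [0.7, 0.95], [0.85, 0.6]]

rec = gla_recurrent(q, k, v, alpha)
ref = gla_full_recompute(q, k, v, alpha)
print(all(abs(a - b) < 1e-9 for ro, fo in zip(rec, ref) for a, b in zip(ro, fo)))
```

The recurrent form does a constant amount of work per new token, which is the property a sampler would need to exploit to realize the long-context speedup.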

Citation / lineage

Builds on: Gated Linear Attention (Yang et al. 2023), Grokfast (Lee et al. 2024), grokking (Power et al. 2022), hybrid linear/attention stacks (Jamba, MiniMax).
