FlashMemory DS-V4 Retriever
A lightweight retriever that sparsifies the DeepSeek-V4 CSA KV cache. Given a decode-token hidden state, it predicts which compressed-K chunks the next ~64 tokens will attend to, keeping only those on the GPU and offloading the rest.
In downstream evaluation it matches or beats the full-attention baseline on reasoning-heavy long-context tasks (RULER, LongMemEval, LongBench V2) while reducing KV-cache usage by ~85–90%. Precise needle-retrieval tasks require an additional threshold-fallback mechanism, which is not included in this release.
Quick start
pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
Usage
from retriever import FlashMemoryRetriever
model = FlashMemoryRetriever.from_checkpoint(
"weights/flashmemory_ds_v4.safetensors", device="cuda"
)
# hidden: [B, 4096] decode hidden state
# compressed_k: [B, N, 132] uint8 CSA keys
# positions: [B] int64 token positions
scores = model.ensemble(hidden, compressed_k, positions, mode="max") # [B, N]
keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # boolean mask
compressed_k format: each chunk is 128 bytes of float8_e4m3 values plus a 4-byte float32 scale, 132 bytes in total (the last dimension of compressed_k). See make_mock_compressed_k() in demo.py.
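To make the chunk layout concrete, here is a minimal standard-library sketch that dequantizes one 132-byte chunk. It rests on several assumptions not confirmed by this README: the scale is stored after the values, it is little-endian, dequantization is `value * scale`, and the FP8 variant is e4m3fn (exponent bias 7, no infinities). The helper names `decode_e4m3` and `dequantize_chunk` are hypothetical; `make_mock_compressed_k()` in demo.py remains the reference for the actual layout.

```python
import math
import struct

VALUES_PER_CHUNK = 128               # FP8 values per chunk (from the format description)
CHUNK_BYTES = VALUES_PER_CHUNK + 4   # plus a 4-byte float32 scale


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 byte, assuming the e4m3fn variant (bias 7, no infinities)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan                                   # the only NaN encoding
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6        # subnormal range
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_chunk(chunk: bytes) -> list:
    """Return 128 floats from one chunk.

    Assumed layout: 128 FP8 bytes, then a little-endian float32 scale.
    """
    if len(chunk) != CHUNK_BYTES:
        raise ValueError(f"expected {CHUNK_BYTES} bytes, got {len(chunk)}")
    (scale,) = struct.unpack("<f", chunk[VALUES_PER_CHUNK:])
    return [decode_e4m3(b) * scale for b in chunk[:VALUES_PER_CHUNK]]
```

For example, a chunk whose value bytes are all `0x38` (1.0 in e4m3) with a scale of 0.5 dequantizes to 128 copies of 0.5. In practice the same step would be done in bulk on the GPU; the per-byte loop here is only meant to show the arithmetic.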
Architecture
3-layer joint model (l10, l12, l20), 128 heads, 2048 LoRA rank. Per-layer
sigmoid scores are ensembled (max or mean) per chunk.
hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
                → weights_proj → fused_w [B,128]
compressed_k → FP8 dequant → k [B,N,128]
score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) ) ∈ [0,1]
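The scoring formula and the cross-layer ensemble can be sketched in plain Python for a single batch element. This is an illustration only, not the code in retriever.py: it assumes the sum runs over heads, that `q` is a list of per-head query vectors, and that `fused_w` holds one weight per head. The function names `chunk_score` and `ensemble_scores` are hypothetical.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def chunk_score(q, k, fused_w) -> float:
    """Score one chunk for one layer.

    q:       list of per-head query vectors (heads x dim)
    k:       dequantized chunk key (dim)
    fused_w: one weight per head (heads)
    """
    total = 0.0
    for q_head, w in zip(q, fused_w):
        dot = sum(qi * ki for qi, ki in zip(q_head, k))
        total += max(dot, 0.0) * w          # relu(k @ q^T) weighted per head
    return _sigmoid(total)


def ensemble_scores(per_layer_scores, mode="max"):
    """Combine per-layer scores into one score per chunk.

    per_layer_scores: one list per layer, each holding one score per chunk.
    """
    per_chunk = list(zip(*per_layer_scores))
    if mode == "max":
        return [max(s) for s in per_chunk]
    if mode == "mean":
        return [sum(s) / len(s) for s in per_chunk]
    raise ValueError(f"unknown mode: {mode!r}")
```

With the three layers listed above, `per_layer_scores` would hold three lists of N scores each. The "max" mode keeps a chunk's highest layer score, while "mean" averages the layers.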
Toy inference reference
toy_flashmemory_inference.py illustrates how the retriever drives memory
recall during decode: every 64 steps it re-scores all chunks, and unselected
ones are masked from attention (equivalent to "not recalled to GPU").
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
The decoder is a few toy layers with random weights; it is not a real DeepSeek-V4 model. The retriever, scoring math, and decode-time control flow are real.
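The decode-time control flow described above can be outlined as follows. This is a simplified sketch, not the contents of toy_flashmemory_inference.py: `score_fn` is a hypothetical callback standing in for the retriever, the names `topk_mask` and `decode_with_recall` are illustrative, and re-scoring at step 0 is an assumption.

```python
RESCORE_INTERVAL = 64  # re-score cadence described above


def topk_mask(scores, top_k):
    """Return a boolean mask that is True for the top_k highest-scoring chunks."""
    k = min(top_k, len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    mask = [False] * len(scores)
    for i in order[:k]:
        mask[i] = True
    return mask


def decode_with_recall(num_steps, score_fn, top_k):
    """Run a skeleton decode loop and record the mask chosen at each re-score.

    score_fn(step) -> list of per-chunk scores (hypothetical stand-in for the
    retriever). Returns a list of (step, mask) pairs.
    """
    mask = None
    history = []
    for step in range(num_steps):
        if step % RESCORE_INTERVAL == 0:
            # Re-score every chunk and refresh the set of recalled chunks.
            mask = topk_mask(score_fn(step), top_k)
            history.append((step, mask))
        # A real decoder would run attention here, restricted to chunks
        # where mask is True; unselected chunks are treated as not recalled.
    return history
```

Between re-scores the mask stays fixed, so chunk transfers to the GPU would only need to happen once per 64-step window rather than on every token.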
Files
| File | Purpose |
|---|---|
| retriever.py | FlashMemoryRetriever model (torch-only, self-contained) |
| demo.py | minimal demo with mock inputs |
| toy_flashmemory_inference.py | toy sparse-decode loop |
| weights/flashmemory_ds_v4.safetensors | trained weights (~510 MB) |
Citation
If you use FlashMemory in your research, please cite:
@article{wang2026flashmemory,
title = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
author = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
Yujiu Yang and Haitao Mi and Dong Yu},
year = {2026},
journal = {arXiv preprint arXiv:2606.09079},
url = {https://huggingface.co/papers/2606.09079},
}
License
MIT