FlashMemory DS-V4 Retriever

A lightweight retriever that sparsifies the DeepSeek-V4 CSA KV cache. Given the hidden state of a decode token, it predicts which compressed-K chunks the next ~64 tokens will attend to. Only those chunks are kept on GPU, and the rest are offloaded.

In downstream evaluation, it matches or outperforms the full-attention baseline on reasoning-heavy long-context benchmarks (RULER, LongMemEval, LongBench V2) while reducing KV-cache usage by roughly 85–90%. Precise needle-retrieval tasks require an additional threshold-based fallback mechanism, which is not included in this release.

Quick start

pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors

Usage

from retriever import FlashMemoryRetriever

model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda"
)

# hidden: [B, 4096] decode hidden state
# compressed_k: [B, N, 132] uint8 CSA keys
# positions: [B] int64 token positions

scores = model.ensemble(hidden, compressed_k, positions, mode="max")        # [B, N]
keep   = model.select_topk(hidden, compressed_k, positions, top_k=512)      # boolean mask
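`select_topk` returns a boolean keep-mask over the N chunks. The following is a minimal pure-Python sketch of what such a mask could look like, assuming it marks the `top_k` highest-scoring chunks as True. The helper name `topk_mask` is hypothetical and not part of retriever.py, and ties here are broken in favor of the lower chunk index, which may differ from the real implementation.

```python
def topk_mask(scores: list[float], top_k: int) -> list[bool]:
    """Return a keep-mask with True at the top_k highest-scoring positions (hypothetical helper)."""
    if top_k <= 0:
        return [False] * len(scores)
    # Sort chunk indices by descending score. Python's sort is stable, so
    # equal scores keep their original ascending-index order.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    keep = set(order[:top_k])
    return [i in keep for i in range(len(scores))]


scores = [0.10, 0.90, 0.45, 0.90, 0.05]
print(topk_mask(scores, top_k=2))  # [False, True, False, True, False]
```

If N is at most `top_k`, every chunk is kept.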

compressed_k format: each chunk is 132 bytes, consisting of 128 float8_e4m3 values (1 byte each) plus a 4-byte float32 scale. See make_mock_compressed_k() in demo.py for an example.
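The card does not specify the byte order within a chunk or the exact dequantization rule. The pure-Python sketch below decodes one chunk under three assumptions: the 128 FP8 bytes come first, the scale is stored little-endian in the last 4 bytes, and each value is multiplied by that scale. It also assumes the "fn" variant of E4M3 (exponent bias 7, no infinities, only S.1111.111 encodes NaN), which is what PyTorch calls `float8_e4m3fn`. `fp8_e4m3fn_to_float` and `dequant_chunk` are hypothetical helpers, not functions from retriever.py or demo.py.

```python
import math
import struct

CHUNK_BYTES = 132  # 128 FP8 values + 4-byte float32 scale
N_VALUES = 128


def fp8_e4m3fn_to_float(byte: int) -> float:
    """Decode one E4M3 'fn' byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    man = byte & 0x07
    if exp == 0x0F and man == 0x07:
        return math.nan  # the only NaN pattern; this format has no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)  # normal


def dequant_chunk(chunk: bytes) -> list[float]:
    """Dequantize one 132-byte chunk (assumed layout: values first, then scale)."""
    if len(chunk) != CHUNK_BYTES:
        raise ValueError(f"expected {CHUNK_BYTES} bytes, got {len(chunk)}")
    (scale,) = struct.unpack("<f", chunk[N_VALUES:])  # assumed little-endian
    return [fp8_e4m3fn_to_float(b) * scale for b in chunk[:N_VALUES]]


# 0x38 encodes 1.0 in E4M3, so every dequantized value equals the scale.
chunk = bytes([0x38] * N_VALUES) + struct.pack("<f", 0.5)
k = dequant_chunk(chunk)
print(k[:3])  # [0.5, 0.5, 0.5]
```

The largest finite E4M3 magnitude is 448 (byte 0x7E), so the per-chunk scale is what restores the original dynamic range.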

Architecture

A joint model over three layers (l10, l12, l20), with 128 heads and LoRA rank 2048. For each chunk, the per-layer sigmoid scores are ensembled by taking either the max or the mean across layers.
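The per-chunk ensembling step can be sketched in pure Python as follows. This is a minimal illustration of the "max or mean across layers" rule described above. `ensemble_scores` is a hypothetical helper, not the `model.ensemble` method, and the scores in the example are made up.

```python
def ensemble_scores(per_layer: list[list[float]], mode: str = "max") -> list[float]:
    """Combine per-layer chunk scores (shape [L][N]) into one score per chunk (shape [N])."""
    if mode not in ("max", "mean"):
        raise ValueError(f"unknown mode: {mode!r}")
    n_layers = len(per_layer)
    combined = []
    # zip(*per_layer) yields one tuple of L scores per chunk.
    for chunk_scores in zip(*per_layer):
        if mode == "max":
            combined.append(max(chunk_scores))
        else:
            combined.append(sum(chunk_scores) / n_layers)
    return combined


# Made-up sigmoid scores from three layers (l10, l12, l20) for three chunks.
layer_scores = [
    [0.2, 0.8, 0.5],
    [0.4, 0.6, 0.5],
    [0.6, 0.1, 0.5],
]
print(ensemble_scores(layer_scores, "max"))  # [0.6, 0.8, 0.5]
print(ensemble_scores(layer_scores, "mean"))
```

`max` keeps a chunk if any single layer rates it highly. `mean` favors chunks that all three layers agree on.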

hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
               → weights_proj → fused_w [B,128]
compressed_k    → FP8 dequant → k [B,N,128]
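The query path applies a Hadamard transform after RoPE. Hadamard rotations are commonly used to spread outlier values across dimensions before low-precision dot products, but the card does not say why this one is used here. The sketch below is an orthonormal fast Walsh-Hadamard transform over one head vector. The 1/sqrt(n) normalization and the natural (Sylvester) ordering are assumptions, and `fwht` is a hypothetical helper, not code from retriever.py.

```python
import math


def fwht(x: list[float]) -> list[float]:
    """Orthonormal fast Walsh-Hadamard transform (Sylvester ordering), O(n log n)."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        # Butterfly pass: combine pairs that are h apart within blocks of size 2h.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal (and its own inverse)
    return [v * scale for v in y]


q_head = [float(i) for i in range(128)]  # one 128-dim head vector
rotated = fwht(q_head)
```

Because the normalized transform is orthogonal, it preserves dot products when it is applied to both operands. Applying it twice returns the input.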

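For each layer, the score of chunk n is

$$\mathrm{score}_n = \sigma\Big(\sum_{h=1}^{128} \mathrm{ReLU}\big(k_n \cdot q_h\big)\, w_h\Big) \in [0,1]$$

where k_n is the dequantized key of chunk n, q_h is the query of head h, and w = fused_w.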
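As a check on the formula, here is a minimal pure-Python sketch that scores a single chunk for a single layer. In batched form, this corresponds to relu(k @ qᵀ) with shape [N, H], weighted by fused_w and summed over heads. `chunk_score` is a hypothetical helper, and the small vectors in the example are made up. A numerically stable sigmoid is used so that large negative sums do not overflow `math.exp`.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def chunk_score(k: list[float], q: list[list[float]], w: list[float]) -> float:
    """sigmoid( sum_h relu(k . q_h) * w_h ) for one chunk key k, queries q[h], weights w[h]."""
    total = 0.0
    for q_h, w_h in zip(q, w):
        dot = sum(ki * qi for ki, qi in zip(k, q_h))
        total += max(dot, 0.0) * w_h  # ReLU drops heads with negative similarity
    return _sigmoid(total)


# Two toy heads of dimension 2: head 0 matches k, head 1 is anti-aligned.
print(chunk_score([1.0, 0.0], [[1.0, 0.0], [-1.0, 0.0]], [2.0, 5.0]))  # sigmoid(2.0)
```

One consequence of this form is that a chunk with no positive head match gets a score of exactly 0.5 (sigmoid of 0), not 0.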

Toy inference reference

toy_flashmemory_inference.py illustrates how the retriever drives memory recall during decoding. Every 64 decode steps it re-scores all chunks, and unselected chunks are masked out of attention (equivalent to "not recalled to GPU").

python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors

The decoder consists of a few toy layers with random weights, so it is not the real DeepSeek-V4 model. The retriever, the scoring math, and the decode-time control flow are real.
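The decode-time control flow described above can be sketched as a loop that refreshes a keep-mask every 64 steps and reuses it in between. This is a minimal illustration under two assumptions. First, the number of chunks is held fixed, whereas in a real decode it grows as tokens are generated. Second, a caller-supplied `score_fn` stands in for the retriever. `sparse_decode` and `score_fn` are hypothetical names, not the API of toy_flashmemory_inference.py.

```python
import random
from collections.abc import Callable

REFRESH_INTERVAL = 64  # re-scoring cadence stated in this card


def sparse_decode(
    n_steps: int,
    n_chunks: int,
    top_k: int,
    score_fn: Callable[[int], list[float]],
) -> list[list[bool]]:
    """Return the keep-mask in effect at each decode step."""
    masks: list[list[bool]] = []
    keep: list[bool] = []
    for step in range(n_steps):
        if step % REFRESH_INTERVAL == 0:
            # Re-score every chunk and keep only the top_k highest-scoring ones.
            scores = score_fn(step)
            order = sorted(range(n_chunks), key=lambda i: -scores[i])
            selected = set(order[:top_k])
            keep = [i in selected for i in range(n_chunks)]
        # A real decoder would mask out chunks with keep[i] == False here,
        # i.e. treat them as not recalled to GPU.
        masks.append(keep)
    return masks


rng = random.Random(0)
masks = sparse_decode(200, 1000, 128, lambda step: [rng.random() for _ in range(1000)])
print(len(masks), sum(masks[-1]))  # 200 128
```

Between refreshes, the mask is reused unchanged, so the retriever's cost is paid only once per 64 decode steps.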

Files

File                                     Purpose
retriever.py                             FlashMemoryRetriever model (torch-only, self-contained)
demo.py                                  Minimal demo with mock inputs
toy_flashmemory_inference.py             Toy sparse-decode loop
weights/flashmemory_ds_v4.safetensors    Trained weights (~510 MB)

Citation

If you use FlashMemory in your research, please cite:

@article{wang2026flashmemory,
  title   = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
  author  = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
             Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
             Yujiu Yang and Haitao Mi and Dong Yu},
  year    = {2026},
  journal = {arXiv preprint arXiv:2606.09079},
  url     = {https://huggingface.co/papers/2606.09079},
}

License

MIT
