skhavin's picture
fix: resolve budget 256 eviction anomaly by implementing robust proportional split-budget boosting
10e074b
Raw
History Blame Contribute Delete
4.39 kB
"""
eviction.py — Token scoring and KV cache pruning.
Core O(n) eviction policy:
1. Score each token position using offline-profiled prototype centroids.
2. Keep the top-budget tokens (attention sink + recency anchors + semantic prototypes).
3. Prune the KV cache to exactly `budget` positions.
This is coordinate-free and RoPE-compatible: we only select positions, never
reorder them, so relative position encodings remain valid.
"""
from __future__ import annotations
import torch
import numpy as np
from typing import Optional, Dict, Tuple, List
from .utils import to_tuple_kv, to_dynamic_cache
def score_tokens(
    prototypes: Optional[Dict],
    seq_len: int,
    budget: int,
) -> np.ndarray:
    """
    Score all token positions using prototype centroid histograms.

    Algorithm (O(seq_len) work per profiled entry):
    - For each entry in ``prototypes``, take its first centroid and add, for
      every position ``p``, the sum of the centroid's first
      ``min(len(centroid), seq_len - p)`` values (a prefix sum lookup).
    - Boost the attention sinks (the first 4 positions) by ``100 * peak``.
    - Boost a recency window at the tail, ``max(8, budget // 2)`` positions
      wide (capped at ``seq_len``), by ``50 * peak``.
    - Add a small deterministic tiebreaker that grows with position index.

    ``peak`` is the largest prototype score, or 1.0 when no score is positive,
    so the boosts dominate the prototype-derived scores.

    Args:
        prototypes: Output of ``build_prototypes()``. Each value must expose
            ``data["centroids"][0]`` as a 1-D sequence of numbers. If None,
            only the sink/recency boosts and the tiebreaker contribute.
        seq_len: Current sequence length to score.
        budget: Target number of tokens to keep; only used here to size the
            recency window.

    Returns:
        scores: (seq_len,) float64 array. Higher = more important.
    """
    sink_tokens = 4
    min_recency = 8

    scores = np.zeros(seq_len, dtype=np.float64)
    if seq_len == 0:
        # Nothing to score. Returning early also avoids ``max()`` on an empty
        # array, which raises ValueError.
        return scores

    if prototypes is not None:
        # Number of positions from p to the end of the sequence, inclusive:
        # seq_len, seq_len - 1, ..., 1. Always >= 1, so the lookup index below
        # is never negative.
        tail_len = seq_len - np.arange(seq_len)
        for data in prototypes.values():
            centroid = data["centroids"][0]
            max_d = min(len(centroid), seq_len)
            if max_d == 0:
                continue
            cumsum = np.cumsum(centroid[:max_d])
            # Vectorised form of: scores[p] += cumsum[min(max_d, seq_len - p) - 1]
            scores += cumsum[np.minimum(tail_len, max_d) - 1]

    # Scale the boosts relative to the largest prototype score so they always
    # outrank it; fall back to 1.0 when there is no positive score.
    top = scores.max()
    peak = top if top > 0 else 1.0

    # 1. Attention sinks: slicing clamps automatically when seq_len < 4.
    scores[:sink_tokens] += peak * 100.0

    # 2. Recency window: half the budget, at least 8, never more than seq_len.
    recency_window = min(max(min_recency, budget // 2), seq_len)
    scores[seq_len - recency_window:] += peak * 50.0

    # Deterministic tiebreaker (prefer later tokens among equals).
    scores += np.linspace(0, 1e-4, seq_len)
    return scores
def select_indices(scores: np.ndarray, budget: int) -> List[int]:
    """
    Return the indices of the ``budget`` highest scores, in ascending order.

    Ascending order preserves the original sequence order of the kept tokens.

    Args:
        scores: 1-D array of per-position scores. Higher = more important.
        budget: Number of indices to keep. Values larger than ``len(scores)``
            keep everything; zero or negative values keep nothing.

    Returns:
        Sorted list of Python ints with ``min(max(budget, 0), len(scores))``
        entries.
    """
    actual_budget = max(0, min(budget, len(scores)))
    if actual_budget == 0:
        # Must be handled explicitly: ``order[-0:]`` is ``order[0:]`` and would
        # select every index instead of none.
        return []
    # A stable sort makes exact ties deterministic: among equal scores the
    # later positions land last in ``order`` and are therefore kept first.
    order = np.argsort(scores, kind="stable")
    return sorted(order[-actual_budget:].tolist())
def prune_kv_cache(
    past_key_values,
    indices: List[int],
    device: torch.device,
):
    """
    Prune a KV cache to the given token indices.

    Every key and value tensor is gathered along dimension 2 with
    ``index_select``, so the relative order of the kept positions is whatever
    order ``indices`` is in.

    Args:
        past_key_values: DynamicCache or legacy tuple from a model forward pass;
            anything ``to_tuple_kv`` accepts.
        indices: Sorted list of token indices to keep.
        device: Device on which the index tensor is first created. If a
            key/value tensor lives on a different device (e.g. a model sharded
            across several devices), the index is copied to that tensor's
            device instead of raising a device-mismatch error.

    Returns:
        Pruned KV cache as produced by ``to_dynamic_cache`` (per the original
        note: DynamicCache if transformers ≥ 4.38, else tuple).
    """
    base_idx = torch.tensor(indices, dtype=torch.long, device=device)
    # One index copy per device, created lazily, so a multi-device cache pays
    # for at most one transfer per device rather than one per layer.
    idx_by_device = {base_idx.device: base_idx}

    def _idx_for(tensor):
        """Return the index tensor located on the same device as ``tensor``."""
        dev = tensor.device
        if dev not in idx_by_device:
            idx_by_device[dev] = base_idx.to(dev)
        return idx_by_device[dev]

    # NOTE(review): dim 2 is presumably the sequence axis of a
    # (batch, heads, seq, head_dim) layout — confirm against to_tuple_kv.
    pruned = tuple(
        (k.index_select(2, _idx_for(k)), v.index_select(2, _idx_for(v)))
        for k, v in to_tuple_kv(past_key_values)
    )
    return to_dynamic_cache(pruned)
def evict(
    past_key_values,
    budget: int,
    prototypes: Optional[Dict],
    seq_len: int,
    device: torch.device,
):
    """
    Run the full eviction pipeline on a KV cache: score, select, prune.

    Args:
        past_key_values: The cache to prune, passed through to
            ``prune_kv_cache``.
        budget: Maximum number of token positions to keep.
        prototypes: Prototype data forwarded to ``score_tokens`` (may be None).
        seq_len: Current number of token positions in the cache.
        device: Device forwarded to ``prune_kv_cache`` for the index tensor.

    Returns:
        ``past_key_values`` itself, untouched, when ``seq_len`` does not exceed
        ``budget``; otherwise the result of ``prune_kv_cache`` for the selected
        positions.
    """
    over_budget = seq_len > budget
    if not over_budget:
        # Cache already fits: hand back the very same object.
        return past_key_values

    kept_positions = select_indices(
        score_tokens(prototypes, seq_len, budget),
        budget,
    )
    return prune_kv_cache(past_key_values, kept_positions, device)