skhavin's picture
fix: change ProactiveCachePress to inherit from ScorerPress for KVPress compatibility
f044920
Raw
History Blame Contribute Delete
3.76 kB
"""
press.py — KVPress-compatible wrapper for ProactiveCache eviction.
Implements the BasePress API from NVIDIA's kvpress library so that
ProactiveCache can be benchmarked directly against the 20+ methods
in the KVPress standard evaluation suite.
Usage (requires: pip install kvpress):
from proactive_cache import ProactiveCachePress
press = ProactiveCachePress(compression_ratio=0.75, prototype_path="...")
# Use with kvpress evaluation harness
"""
from __future__ import annotations
import os
import pickle
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional
from .eviction import score_tokens
# Shim for Python 3.13+ which removed the 'pipes' module (needed by fire/kvpress)
try:
import pipes
except ImportError:
import sys, shlex
sys.modules['pipes'] = shlex
# ── KVPress ScorerPress compatibility shim ────────────────────────────────────
try:
from kvpress import ScorerPress
_KVPRESS_AVAILABLE = True
except ImportError:
_KVPRESS_AVAILABLE = False
class ScorerPress:
"""Minimal shim — allows import without kvpress installed."""
def __init__(self):
self.compression_ratio = 0.0
def score(self, module, hidden_states, keys, values, attentions, kwargs):
raise NotImplementedError("Install kvpress: pip install kvpress")
@dataclass
class ProactiveCachePress(ScorerPress):
"""
KVPress-compatible Proactive KV Cache eviction plugin.
Implements the BasePress.score() hook, called once per attention layer
during prefill. Returns a scalar importance score per token position —
higher score = keep, lower score = evict (following KVPress convention).
Args:
compression_ratio: Fraction of tokens to EVICT [0.0, 1.0).
e.g. 0.75 → keep 25% of the KV cache (budget = seq_len * 0.25).
prototype_path: Path to a prototypes .pkl file from ``ProactiveCache.profile()``.
If None, falls back to attention-sink + recency-only scoring.
Example:
press = ProactiveCachePress(compression_ratio=0.75, prototype_path="protos.pkl")
"""
compression_ratio: float = 0.5
prototype_path: Optional[str] = None
def __post_init__(self):
self._prototypes = None
if self.prototype_path and os.path.exists(self.prototype_path):
with open(self.prototype_path, "rb") as f:
self._prototypes = pickle.load(f)
print(f"[ProactiveCachePress] Loaded {len(self._prototypes)} prototypes "
f"from {self.prototype_path}")
else:
print("[ProactiveCachePress] No prototypes loaded — using sink+recency scoring.")
def score(self, module, hidden_states, keys, values, attentions, kwargs):
"""
KVPress hook: called once per attention layer during the prefill pass.
Returns:
scores: (batch, num_heads, seq_len) float tensor.
Higher = more important. KVPress will keep the top-K tokens
where K = seq_len * (1 - compression_ratio).
"""
batch_size, num_heads, seq_len, head_dim = keys.shape
budget = max(1, int(seq_len * (1.0 - self.compression_ratio)))
device = keys.device
# Build position scores (O(n), query-free)
proto_scores = score_tokens(self._prototypes, seq_len, budget)
proto_tensor = torch.tensor(proto_scores, dtype=torch.float32, device=device)
# Broadcast (1, 1, seq_len) → (batch, num_heads, seq_len)
scores = proto_tensor.unsqueeze(0).unsqueeze(0).expand(batch_size, num_heads, seq_len)
return scores.contiguous()