skhavin's picture
feat: initial release of proactive-cache v0.1.0
b786614
Raw
History Blame Contribute Delete
5.68 kB
"""
profiler.py — Offline attention pattern collection for any HuggingFace model.
Run once per model family to collect per-head attention distributions.
The resulting patterns are then clustered into prototypes (see prototypes.py)
which drive the O(n) eviction policy at inference time.
Usage:
from proactive_cache import profile_model
patterns = profile_model(model, tokenizer, corpus="wikitext", num_docs=50)
"""
from __future__ import annotations
import torch
import numpy as np
from typing import Optional, Union, List, Dict
from tqdm import tqdm
def profile_model(
    model,
    tokenizer,
    corpus: Union[str, List[str]] = "wikitext",
    num_docs: int = 50,
    seq_len: int = 512,
    output_attentions: bool = True,
) -> List[Dict]:
    """
    Collect per-head attention distributions over a calibration corpus.

    Each document is tokenized (truncated to ``seq_len`` tokens), run through
    the model once, and reduced to one vector per (layer, head): the attention
    each position receives, averaged over all query positions. Documents that
    tokenize to fewer than 32 tokens are skipped.

    The model is switched to eval mode and left in that mode.

    Args:
        model: A HuggingFace CausalLM model (any architecture) that can
            return attention weights.
        tokenizer: Corresponding tokenizer.
        corpus: Either a dataset name ("wikitext", "pg19", or another
            HuggingFace dataset name) or a list of raw text strings.
        num_docs: Number of documents to sample for profiling.
        seq_len: Maximum number of tokens per document; also the length of
            every returned vector.
        output_attentions: Forwarded to the model call. Must be True for
            profiling to succeed.

    Returns:
        A list with one entry per profiled document. Each entry is a dict
        mapping ``(layer_idx, head_idx)`` to a float32 ``np.ndarray`` of shape
        ``(seq_len,)``; positions beyond the document's length are zero.

    Raises:
        RuntimeError: If the model returns no attention weights, or if no
            document in the corpus is long enough to be profiled.
    """
    model.eval()
    device = next(model.parameters()).device

    # ── Load corpus ───────────────────────────────────────────────────────────
    texts = _load_corpus(corpus, num_docs)
    print(f"[ProactiveCache] Profiling on {len(texts)} documents, seq_len={seq_len}")

    all_patterns: List[Dict] = []
    for text in tqdm(texts, desc="Profiling attention patterns"):
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=seq_len,
        )
        input_ids = enc["input_ids"].to(device)
        # Very short documents carry too little signal; skip them.
        if input_ids.shape[1] < 32:
            continue

        with torch.no_grad():
            out = model(
                input_ids,
                output_attentions=output_attentions,
                use_cache=False,
            )

        attentions = out.attentions
        # Some attention backends return a tuple of None (or nothing at all)
        # instead of weights; fail with a clear message rather than crashing
        # later on ``attn[0]``.
        if not attentions or any(layer is None for layer in attentions):
            raise RuntimeError(
                "Model did not return attention weights. "
                "Call profile_model with output_attentions=True and load the "
                "model with an attention implementation that can return "
                "weights (e.g. attn_implementation=\"eager\")."
            )

        doc_pattern = {}
        for layer_idx, attn in enumerate(attentions):
            # attn is expected to be (batch=1, num_heads, query, key).
            # Average over the query axis on the model's device so that only
            # a (heads, seq) array is copied to host instead of the full
            # (heads, seq, seq) matrix.
            received = attn[0].float().mean(dim=-2).cpu().numpy()
            num_heads, slen = received.shape
            keep = min(slen, seq_len)
            for head_idx in range(num_heads):
                # Zero-pad (or truncate) to a fixed length of seq_len.
                padded = np.zeros(seq_len, dtype=np.float32)
                padded[:keep] = received[head_idx, :keep]
                doc_pattern[(layer_idx, head_idx)] = padded
        all_patterns.append(doc_pattern)

    if not all_patterns:
        raise RuntimeError("No valid documents found in corpus. Try increasing num_docs.")

    print(f"[ProactiveCache] Profiled {len(all_patterns)} documents across "
          f"{len(all_patterns[0])} (layer, head) pairs.")
    return all_patterns
def _load_corpus(corpus: Union[str, List[str]], num_docs: int) -> List[str]:
"""Load a text corpus for profiling."""
if isinstance(corpus, list):
return corpus[:num_docs]
if corpus == "wikitext":
return _load_wikitext(num_docs)
elif corpus == "pg19":
return _load_pg19(num_docs)
else:
# Try to load as a HuggingFace dataset name
return _load_hf_dataset(corpus, num_docs)
def _load_wikitext(num_docs: int) -> List[str]:
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-103-v1", split="validation", streaming=True)
texts, current = [], ""
for item in ds:
t = item["text"].strip()
if t:
current += " " + t
if len(current) > 2000:
texts.append(current.strip())
current = ""
if len(texts) >= num_docs:
break
return texts
def _load_pg19(num_docs: int) -> List[str]:
from datasets import load_dataset
ds = load_dataset("emozilla/pg19", split="test", streaming=True)
texts = []
for item in ds:
text = item.get("text", "")
if len(text) > 500:
texts.append(text[:4000])
if len(texts) >= num_docs:
break
return texts
def _load_hf_dataset(name: str, num_docs: int) -> List[str]:
from datasets import load_dataset
try:
ds = load_dataset(name, split="train", streaming=True)
texts = []
for item in ds:
# Try common text field names
for field in ["text", "content", "body", "sentence"]:
if field in item and isinstance(item[field], str) and len(item[field]) > 100:
texts.append(item[field])
break
if len(texts) >= num_docs:
break
return texts
except Exception as e:
raise ValueError(f"Could not load corpus '{name}': {e}")