Spaces:
Sleeping
Sleeping
| """ | |
| profiler.py — Offline attention pattern collection for any HuggingFace model. | |
| Run once per model family to collect per-head attention distributions. | |
| The resulting patterns are then clustered into prototypes (see prototypes.py) | |
| which drive the O(n) eviction policy at inference time. | |
| Usage: | |
| from proactive_cache import profile_model | |
| patterns = profile_model(model, tokenizer, corpus="wikitext", num_docs=50) | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import numpy as np | |
| from typing import Optional, Union, List, Dict | |
| from tqdm import tqdm | |
def profile_model(
    model,
    tokenizer,
    corpus: Union[str, List[str]] = "wikitext",
    num_docs: int = 50,
    seq_len: int = 512,
    output_attentions: bool = True,
    min_tokens: int = 32,
) -> List[Dict[tuple, np.ndarray]]:
    """
    Collect per-head attention distributions over a calibration corpus.

    Each document is tokenized (truncated to ``seq_len`` tokens) and run through
    the model once. For every (layer, head), the attention each key position
    receives is averaged over all query positions.

    Args:
        model: A HuggingFace CausalLM model. It must return per-layer attention
            weights in ``out.attentions``. NOTE(review): models loaded with a
            fused attention kernel (SDPA / FlashAttention) may not return them;
            loading with ``attn_implementation="eager"`` is the usual remedy.
        tokenizer: Corresponding tokenizer.
        corpus: Either a dataset name ("wikitext", "pg19", or any HF Hub
            dataset id) or a list of raw text strings to profile on.
        num_docs: Number of documents to sample for profiling.
        seq_len: Maximum tokens per document. This is also the length of every
            returned pattern array.
        output_attentions: Forwarded to the model's forward call. Attention
            weights are required, so ``False`` leads to the RuntimeError below
            unless the model returns them anyway.
        min_tokens: Documents that tokenize to fewer tokens than this are
            skipped (default 32).

    Returns:
        A list with one dict per profiled (non-skipped) document. Each dict maps
        ``(layer_idx, head_idx)`` to a float32 array of shape ``(seq_len,)``:
        the mean attention received per key position. Positions past the
        document's token length are zero.

    Raises:
        RuntimeError: If the model returns no attention weights, or if every
            document was skipped.
    """
    model.eval()
    device = next(model.parameters()).device

    # ── Load corpus ───────────────────────────────────────────────────────────
    texts = _load_corpus(corpus, num_docs)
    print(f"[ProactiveCache] Profiling on {len(texts)} documents, seq_len={seq_len}")

    all_patterns: List[Dict[tuple, np.ndarray]] = []
    for text in tqdm(texts, desc="Profiling attention patterns"):
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=seq_len,
        )
        input_ids = enc["input_ids"].to(device)
        if input_ids.shape[1] < min_tokens:
            continue  # too short to give a useful attention profile

        with torch.no_grad():
            out = model(
                input_ids,
                output_attentions=output_attentions,
                use_cache=False,
            )

        if out.attentions is None:
            raise RuntimeError(
                "Model did not return attention weights. "
                "Call profile_model with output_attentions=True and load the model "
                "with an attention implementation that exposes them "
                "(e.g. attn_implementation=\"eager\")."
            )

        all_patterns.append(_received_attention(out.attentions, seq_len))
        # Release this document's outputs (all layers' full attention matrices)
        # before the next forward pass, so two copies never coexist in memory.
        del out

    if not all_patterns:
        raise RuntimeError("No valid documents found in corpus. Try increasing num_docs.")

    print(f"[ProactiveCache] Profiled {len(all_patterns)} documents across "
          f"{len(all_patterns[0])} (layer, head) pairs.")
    return all_patterns


def _received_attention(attentions, seq_len: int) -> Dict[tuple, np.ndarray]:
    """Average each head's attention over queries and zero-pad it to ``seq_len``."""
    doc_pattern: Dict[tuple, np.ndarray] = {}
    for layer_idx, attn in enumerate(attentions):
        # attn: (batch=1, num_heads, slen, slen) with query rows and key columns.
        # Reduce over the query axis on-device, so only (num_heads, slen) values
        # are copied to host instead of the full attention matrix.
        # NOTE(review): under a causal mask, key j is attended only by queries
        # i >= j, so early positions average over more non-zero entries.
        received = attn[0].float().mean(dim=1).cpu().numpy()  # (heads, slen)
        num_heads, slen = received.shape
        keep = min(slen, seq_len)
        for head_idx in range(num_heads):
            padded = np.zeros(seq_len, dtype=np.float32)
            padded[:keep] = received[head_idx, :keep]
            doc_pattern[(layer_idx, head_idx)] = padded
    return doc_pattern
| def _load_corpus(corpus: Union[str, List[str]], num_docs: int) -> List[str]: | |
| """Load a text corpus for profiling.""" | |
| if isinstance(corpus, list): | |
| return corpus[:num_docs] | |
| if corpus == "wikitext": | |
| return _load_wikitext(num_docs) | |
| elif corpus == "pg19": | |
| return _load_pg19(num_docs) | |
| else: | |
| # Try to load as a HuggingFace dataset name | |
| return _load_hf_dataset(corpus, num_docs) | |
| def _load_wikitext(num_docs: int) -> List[str]: | |
| from datasets import load_dataset | |
| ds = load_dataset("wikitext", "wikitext-103-v1", split="validation", streaming=True) | |
| texts, current = [], "" | |
| for item in ds: | |
| t = item["text"].strip() | |
| if t: | |
| current += " " + t | |
| if len(current) > 2000: | |
| texts.append(current.strip()) | |
| current = "" | |
| if len(texts) >= num_docs: | |
| break | |
| return texts | |
| def _load_pg19(num_docs: int) -> List[str]: | |
| from datasets import load_dataset | |
| ds = load_dataset("emozilla/pg19", split="test", streaming=True) | |
| texts = [] | |
| for item in ds: | |
| text = item.get("text", "") | |
| if len(text) > 500: | |
| texts.append(text[:4000]) | |
| if len(texts) >= num_docs: | |
| break | |
| return texts | |
| def _load_hf_dataset(name: str, num_docs: int) -> List[str]: | |
| from datasets import load_dataset | |
| try: | |
| ds = load_dataset(name, split="train", streaming=True) | |
| texts = [] | |
| for item in ds: | |
| # Try common text field names | |
| for field in ["text", "content", "body", "sentence"]: | |
| if field in item and isinstance(item[field], str) and len(item[field]) > 100: | |
| texts.append(item[field]) | |
| break | |
| if len(texts) >= num_docs: | |
| break | |
| return texts | |
| except Exception as e: | |
| raise ValueError(f"Could not load corpus '{name}': {e}") | |