File size: 3,580 Bytes
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
prototypes.py — Build and manage the offline prototype library.

The prototype library maps each (layer, head) pair to a set of K centroid
attention-distribution vectors, learned via K-Means from the profiling corpus.
At inference time, these centroids drive the O(n) scoring function without
any query lookups.
"""

from __future__ import annotations
import os
import pickle
import numpy as np
from typing import Dict, Optional, List
from sklearn.cluster import KMeans


def build_prototypes(
    patterns: List[Dict],
    n_clusters: int = 4,
    max_seq_len: int = 512,
    random_state: int = 42,
) -> Dict:
    """
    Cluster per-head attention patterns into prototype centroids.

    Args:
        patterns:    Output of ``profile_model()`` — list of dicts mapping
                     ``(layer, head) → np.ndarray`` of shape ``(seq_len,)``.
        n_clusters:  Number of K-Means clusters per head (default 4). Capped
                     at the number of documents available for a head.
        max_seq_len: Maximum sequence length to include in clustering.
                     Longer vectors are truncated to this many positions.
        random_state: Random seed for reproducibility.

    Returns:
        prototypes: Dict mapping ``(layer, head)`` to a dict with
                    ``"centroids"`` (float32 array of shape
                    ``(k, width)`` where ``k <= n_clusters`` and
                    ``width <= max_seq_len``), ``"labels"`` (cluster index
                    per document) and ``"inertia"`` (float).

    Raises:
        ValueError: If ``patterns`` is empty, or if ``n_clusters`` or
                    ``max_seq_len`` is smaller than 1.
    """
    if not patterns:
        raise ValueError("patterns list is empty. Run profile_model() first.")
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}.")
    if max_seq_len < 1:
        raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}.")

    # Use the union of keys across all documents: a head missing from the
    # first document must not be silently dropped from the library.
    keys = sorted({key for p in patterns for key in p})
    prototypes = {}

    for key in keys:
        # Honour max_seq_len by truncating each document's vector.
        rows = [
            np.asarray(p[key])[:max_seq_len]
            for p in patterns
            if key in p
        ]

        # Documents may have different lengths; zero-pad the shorter ones so
        # the matrix is rectangular. When all rows already share one length
        # this is a no-op and the data matches the unpadded input.
        width = max(len(row) for row in rows)
        data = np.stack([
            row if len(row) == width else np.pad(row, (0, width - len(row)))
            for row in rows
        ])  # shape: (num_docs, width)

        # K-Means cannot produce more clusters than there are samples.
        k = min(n_clusters, len(data))
        kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10)
        kmeans.fit(data)

        prototypes[key] = {
            "centroids": kmeans.cluster_centers_.astype(np.float32),
            "labels":    kmeans.labels_,
            "inertia":   float(kmeans.inertia_),
        }

    return prototypes


def save_prototypes(prototypes: Dict, path: str) -> None:
    """Pickle ``prototypes`` to ``path``, creating parent directories as needed.

    Prints a confirmation line once the file has been written.
    """
    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)

    with open(path, "wb") as out_file:
        pickle.dump(prototypes, out_file)

    print(f"[ProactiveCache] Prototypes saved to {path}")


def load_prototypes(path: str) -> Dict:
    """Load a pickled prototype library from ``path``.

    Args:
        path: Location of the pickle file to read.

    Returns:
        The unpickled prototypes object.

    Raises:
        FileNotFoundError: If ``path`` does not exist. The message includes
            the path and a hint on how to generate the file.
    """
    if not os.path.exists(path):
        # Both lines must be f-strings; otherwise the hint shows a literal
        # "{path}" instead of the actual location.
        raise FileNotFoundError(
            f"Prototype file not found: {path}\n"
            f"Run ProactiveCache.profile(model, ..., save_path='{path}') first."
        )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load prototype files that come from a trusted source.
    with open(path, "rb") as f:
        prototypes = pickle.load(f)
    print(f"[ProactiveCache] Loaded {len(prototypes)} prototypes from {path}")
    return prototypes


def prototype_summary(prototypes: Dict) -> str:
    """Return a human-readable summary of a prototype library.

    Args:
        prototypes: Dict keyed by ``(layer, head)`` whose values hold a
            ``"centroids"`` array with at least two dimensions.

    Returns:
        A multi-line report with the number of layers (and their index
        range), the number of distinct head indices, the total number of
        ``(layer, head)`` pairs, and the cluster count and sequence length
        read from the first entry's centroids. Returns
        ``"Empty prototype library."`` when ``prototypes`` is empty.
    """
    num_pairs = len(prototypes)
    if num_pairs == 0:
        return "Empty prototype library."

    layers = sorted({layer for (layer, _) in prototypes})
    heads_per_layer = sorted({head for (_, head) in prototypes})

    # Cluster count and sequence length are taken from the first entry only.
    sample_centroids = next(iter(prototypes.values()))["centroids"]
    n_clusters = sample_centroids.shape[0]
    seq_len = sample_centroids.shape[1]

    # Separate first and last layer indices so e.g. 0 and 11 do not run
    # together as "011"; a single layer is shown as just its index.
    first_layer, last_layer = layers[0], layers[-1]
    if first_layer == last_layer:
        layer_range = f"{first_layer}"
    else:
        layer_range = f"{first_layer}-{last_layer}"

    return (
        f"ProactiveCache Prototype Library\n"
        f"  Layers:          {len(layers)} ({layer_range})\n"
        f"  Heads per layer: {len(heads_per_layer)}\n"
        f"  Total (L, H):    {num_pairs}\n"
        f"  Clusters/head:   {n_clusters}\n"
        f"  Profile seq_len: {seq_len}\n"
    )