File size: 1,499 Bytes
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
utils.py — DynamicCache compatibility helpers.

Handles the transformers DynamicCache ↔ legacy tuple conversion cleanly
across transformers versions 4.38+.
"""

from __future__ import annotations
import torch
from typing import Tuple, Union


# Type aliases
# Legacy ("tuple") KV-cache layout: an outer tuple whose entries are
# (key, value) pairs of tensors.
KVTuple = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]


def to_tuple_kv(past_key_values) -> KVTuple:
    """Normalize a DynamicCache or legacy tuple to a tuple of (k, v) pairs.

    Args:
        past_key_values: A cache object exposing ``to_legacy_cache()`` (such
            as a transformers ``DynamicCache``), an iterable of ``(key,
            value)`` pairs, or ``None`` when no cache exists yet.

    Returns:
        An empty tuple for ``None``; the result of ``to_legacy_cache()`` when
        the object provides that method; otherwise ``tuple(past_key_values)``.
    """
    if past_key_values is None:
        # "No cache yet" is treated as an empty cache instead of failing
        # inside tuple(None) with a TypeError.
        return ()
    if hasattr(past_key_values, "to_legacy_cache"):
        return past_key_values.to_legacy_cache()
    return tuple(past_key_values)


def to_dynamic_cache(kv_tuple: KVTuple):
    """Convert a (k, v) tuple back to DynamicCache for models that require it.

    Args:
        kv_tuple: Legacy cache as a tuple of ``(key, value)`` pairs.

    Returns:
        ``DynamicCache.from_legacy_cache(kv_tuple)`` when transformers
        provides ``DynamicCache`` with that constructor; otherwise
        ``kv_tuple`` itself, unchanged.

    Raises:
        Any exception raised by ``DynamicCache.from_legacy_cache`` while
        converting ``kv_tuple`` is propagated to the caller.
    """
    try:
        from transformers import DynamicCache

        # Resolve the constructor inside the guard so a DynamicCache without
        # from_legacy_cache is still treated as "not available".
        from_legacy = DynamicCache.from_legacy_cache
    except (ImportError, AttributeError):
        # Older transformers — raw tuple is fine
        return kv_tuple
    # The conversion runs outside the try block: an AttributeError caused by
    # malformed input must not be mistaken for a missing DynamicCache.
    return from_legacy(kv_tuple)


def get_device(model) -> torch.device:
    """Get the primary device of a model.

    The device is taken from the first parameter yielded by
    ``model.parameters()``. If the model has no parameters, the first buffer
    from ``model.buffers()`` is used instead; if it has neither, CPU is
    returned.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterators
            whose items carry a ``.device`` attribute.

    Returns:
        The device of the first parameter, else of the first buffer, else
        ``torch.device("cpu")``.
    """
    # next(..., None) avoids leaking a bare StopIteration for
    # parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    first_buffer = next(model.buffers(), None)
    if first_buffer is not None:
        return first_buffer.device
    return torch.device("cpu")


def get_num_layers(past_key_values) -> int:
    """Count the entries of a KV cache, one entry per transformer layer.

    The cache is first normalized with ``to_tuple_kv``; the length of the
    resulting tuple is returned.
    """
    return len(to_tuple_kv(past_key_values))


def get_seq_len(past_key_values) -> int:
    """Report the sequence length currently held in a KV cache.

    The cache is normalized with ``to_tuple_kv``. An empty cache yields 0;
    otherwise the length is read from dimension 2 of the first entry's key
    tensor.
    """
    layers = to_tuple_kv(past_key_values)
    if len(layers) > 0:
        # Key tensor layout: (batch, num_heads, seq_len, head_dim)
        first_key = layers[0][0]
        return first_key.shape[2]
    return 0