Spaces:
Sleeping
Sleeping
File size: 1,499 Bytes
b786614 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | """
utils.py — DynamicCache compatibility helpers.
Handles the transformers DynamicCache ↔ legacy tuple conversion cleanly
across transformers versions 4.38+.
"""
from __future__ import annotations
import torch
from typing import Tuple, Union
# Type aliases
# Legacy KV-cache layout: a tuple with one (key, value) tensor pair per layer,
# for any number of layers. In this module it is used only as an annotation;
# nothing checks it at runtime.
KVTuple = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
def to_tuple_kv(past_key_values) -> KVTuple:
    """Normalize a KV cache to a tuple of per-layer entries.

    Accepts either a cache object exposing ``to_legacy_cache()`` (such as
    transformers' ``DynamicCache``) or an iterable of per-layer ``(k, v)``
    entries in the legacy format.

    Args:
        past_key_values: The cache to normalize, or ``None`` for "no cache".

    Returns:
        A tuple with one entry per layer; ``()`` when *past_key_values* is
        ``None``.
    """
    if past_key_values is None:
        # No cache supplied: treat it as empty so layer/sequence-length helpers
        # report 0 instead of crashing on ``tuple(None)``.
        return ()
    if hasattr(past_key_values, "to_legacy_cache"):
        return past_key_values.to_legacy_cache()
    return tuple(past_key_values)
def to_dynamic_cache(kv_tuple: KVTuple):
    """Convert legacy ``(k, v)`` tuples to a transformers ``DynamicCache``.

    Falls back to returning *kv_tuple* unchanged in any of these cases:
    transformers is not importable, it does not export ``DynamicCache``, or
    its ``DynamicCache`` lacks ``from_legacy_cache``.

    Args:
        kv_tuple: Per-layer cache entries in the legacy tuple format.

    Returns:
        A ``DynamicCache`` built from *kv_tuple*, or *kv_tuple* itself when no
        compatible ``DynamicCache`` is available.

    Raises:
        Any exception raised by ``DynamicCache.from_legacy_cache`` for
        malformed input; such errors are deliberately not swallowed.
    """
    try:
        from transformers import DynamicCache
    except ImportError:
        # Older transformers (or none installed) — raw tuple is fine.
        return kv_tuple
    from_legacy_cache = getattr(DynamicCache, "from_legacy_cache", None)
    if from_legacy_cache is None:
        # DynamicCache without the legacy converter — keep the raw tuple.
        return kv_tuple
    # Kept outside any try: an AttributeError raised while converting a
    # malformed cache is a real bug, not a sign of an older transformers API.
    return from_legacy_cache(kv_tuple)
def get_device(model) -> torch.device:
    """Return the device of the first parameter (or buffer) of *model*.

    For a model spread across several devices this reports only where the
    first registered tensor lives, not every layer.

    Args:
        model: A ``torch.nn.Module``-like object exposing ``parameters()`` and
            ``buffers()``.

    Returns:
        The ``torch.device`` of the first parameter. For parameter-free
        modules, the device of the first buffer.

    Raises:
        ValueError: If *model* has neither parameters nor buffers.
    """
    first = next(model.parameters(), None)
    if first is None:
        # Parameter-free modules can still hold tensors as buffers.
        first = next(model.buffers(), None)
    if first is None:
        # A bare next() would leak StopIteration, which is opaque to callers and
        # becomes RuntimeError when raised inside a generator.
        raise ValueError("Cannot determine device: model has no parameters or buffers")
    return first.device
def get_num_layers(past_key_values) -> int:
    """Count the layers held by a KV cache.

    Args:
        past_key_values: A cache object with ``to_legacy_cache()`` or a legacy
            tuple of per-layer entries; normalized via ``to_tuple_kv``.

    Returns:
        The number of per-layer entries in the normalized cache.
    """
    return len(to_tuple_kv(past_key_values))
def get_seq_len(past_key_values) -> int:
    """Report how many positions a KV cache currently stores.

    Reads axis 2 of the first layer's key tensor. This assumes a key layout of
    ``(batch, num_heads, seq_len, head_dim)``; confirm it for models that lay
    out their cache differently.

    Args:
        past_key_values: A cache object with ``to_legacy_cache()`` or a legacy
            tuple of per-layer entries; normalized via ``to_tuple_kv``.

    Returns:
        0 for an empty cache, otherwise ``shape[2]`` of layer 0's key tensor.
    """
    layers = to_tuple_kv(past_key_values)
    if not layers:
        return 0
    first_key = layers[0][0]
    # (batch, num_heads, seq_len, head_dim): the sequence axis is index 2.
    return first_key.shape[2]
|