| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
|
|
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
@dataclass
class Cache:
    """
    Abstract base class shared by all key/value caches. How the states are stored is left to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stores the new `key_states` and `value_states` of layer `layer_idx` in the cache.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Extra, subclass-specific arguments. They exist so that new kinds of cache can be created without
                changing this signature.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """
        Given the sequence length of the new inputs, returns the usable length of the cache.

        When the cache is bounded and the cached states plus the new inputs would not fit, the usable length is
        whatever room is left once the new inputs are accounted for (`max_length - new_seq_length`). Otherwise it
        is simply the number of states already cached for `layer_idx`.
        """
        cache_limit = self.get_max_length()
        cached_length = self.get_seq_length(layer_idx)
        overflows = cache_limit is not None and cached_length + new_seq_length > cache_limit
        return cache_limit - new_seq_length if overflows else cached_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx, layer_keys in enumerate(self.key_cache):
            # The index tensor is moved to each tensor's own device before selecting along dim 0
            self.key_cache[layer_idx] = layer_keys.index_select(0, beam_idx.to(layer_keys.device))
            layer_values = self.value_cache[layer_idx]
            self.value_cache[layer_idx] = layer_values.index_select(0, beam_idx.to(layer_values.device))

    @property
    def seen_tokens(self):
        """Deprecated accessor: emits a one-time warning and returns `_seen_tokens`, or `None` if it is not set."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        return getattr(self, "_seen_tokens", None)
|
|
|
|
class DynamicCache(Cache):
    """
    A cache whose size grows as more tokens are generated. This is the default for generative models.

    Key and Value states are kept as two lists of tensors with one entry per layer. Each tensor is expected to have
    shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Running count of the tokens passed to `update` for layer 0
        self._seen_tokens = 0

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Backwards-compatible `past_key_value` indexing: returns the `(key, value)` pair of layer `layer_idx`, so
        that e.g. `past_key_value[0][0].shape[2]` still yields the sequence length.
        """
        if layer_idx >= len(self):
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
        return (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __iter__(self):
        """
        Backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:`. Yields one
        `(key, value)` pair per layer.
        """
        for layer_idx in range(len(self)):
            yield self.key_cache[layer_idx], self.value_cache[layer_idx]

    def __len__(self):
        """
        Backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This is the number of layers
        currently held in the cache.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stores the new `key_states` and `value_states` of layer `layer_idx` in the cache.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Tokens are tallied once per forward pass, when the first layer reports in
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if layer_idx < len(self.key_cache):
            # Layer already known: grow it along the sequence dimension
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        else:
            # First time this layer is seen: its states become the initial cache entry
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if layer_idx < len(self.key_cache):
            return self.key_cache[layer_idx].shape[-2]
        # Nothing has been cached for this layer yet
        return 0

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        return tuple((self.key_cache[layer_idx], self.value_cache[layer_idx]) for layer_idx in range(len(self)))

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is None:
            return cache
        for layer_idx, (key_states, value_states) in enumerate(past_key_values):
            cache.update(key_states, value_states, layer_idx)
        return cache
|
|
|
|
class SinkCache(Cache):
    """
    A cache implementing the scheme described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It
    allows the model to generate beyond the length of its context window, without losing fluency in the
    conversation. As it discards past tokens, the model will lose the ability to generate tokens that depend on the
    context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Maps "number of new tokens" -> (rerotation_cos, rerotation_sin), see `_get_rerotation_cos_sin`
        self.cos_sin_rerotation_cache = {}
        # cos/sin tables recorded from the `update` calls of layer 0
        self._cos_cache = None
        self._sin_cache = None
        # Running count of the tokens passed to `update` for layer 0
        self._seen_tokens = 0

    @staticmethod
    def _rotate_half(x):
        """Splits the last dimension of `x` into halves `(x1, x2)` and returns `cat(-x2, x1)`."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotation described by `cos`/`sin` to `key_states`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the `(cos, sin)` pair that re-rotates kept keys from their original positions to their shifted ones.

        The result only depends on the number of new tokens (`key_states.shape[-2]`), so it is computed once per
        distinct count and memoized in `self.cos_sin_rerotation_cache`.
        """
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 for the trigonometric products; the result is cast back to the key dtype below
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # "original" rows start `num_new_tokens` after the sink tokens, "shifted" rows start right after them.
            # The products below are the angle-difference identities, i.e. cos/sin of (shifted - original).
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Nothing has been cached for this layer yet
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. When omitted (or `None`), no re-rotation is applied.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to `None` in the signature, so it must not be dereferenced blindly
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional RoPE inputs; re-rotation only happens when both `cos` and `sin` are supplied
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Tokens are tallied once per forward pass, when the first layer reports in
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Record the cos/sin tables once per forward pass (layer 0 only)
        if using_rope and layer_idx == 0:
            if cos.dim() == 2:
                # 2-D tables are stored as they are
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                # Otherwise only the first entry along dim 0 is used, accumulated until `window_length` rows exist
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        if len(self.key_cache) <= layer_idx:
            # Empty cache for this layer: the incoming states become the initial entry
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Still below the window length: grow along the sequence dimension
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Window is full: keep the sink tokens plus the most recent entries, dropping the ones in between.
            # NOTE(review): the slice start is only negative (i.e. "the last k entries") while
            # `num_sink_tokens + new tokens < window_length` -- confirm callers respect that.
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # With RoPE, the kept keys move to earlier positions and have to be re-rotated accordingly
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated; the rest pass through untouched
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # New key layout: sink tokens, shifted (re-rotated) recent tokens, new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            # Values carry no positional rotation, so they are only sliced and concatenated
            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
|
|
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. If `None`,
            `config.max_position_embeddings` is used instead.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(
        self, config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int], device, dtype=None
    ) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some configs define `head_dim` explicitly; otherwise it is derived from the hidden size
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        # Configs without a `num_key_value_heads` attribute (or with it set to `None`) fall back to the number of
        # attention heads instead of failing on the attribute lookup
        num_key_value_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads = config.num_attention_heads if num_key_value_heads is None else num_key_value_heads

        # One pre-allocated, zero-filled key buffer and value buffer per hidden layer
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for _ in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            # Tag the buffers as having a static address for dynamo (presumably so that in-place cache writes
            # under `torch.compile` keep working on the same storage -- see the PyTorch docs to confirm)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache.

        Return:
            A tuple containing the updated key and value states. These are the full pre-allocated buffers of the
            layer, written in place.

        Raises:
            ValueError: If `cache_position` is not provided through `cache_kwargs`.
        """
        # `cache_kwargs` defaults to `None` in the signature, so it must not be dereferenced blindly
        if cache_kwargs is None:
            cache_kwargs = {}
        cache_position = cache_kwargs.get("cache_position")
        if cache_position is None:
            # Indexing with `None` would insert a new axis rather than select positions, so fail loudly instead
            raise ValueError(
                "`StaticCache.update` requires `cache_position` in `cache_kwargs` to know where to write in the "
                "cache."
            )

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # In-place writes along the sequence dimension (dim 2) of the pre-allocated buffers
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states that were seen by the model.

        A position counts as occupied when its key vector holds at least one non-zero value; only the first batch
        member and first head are inspected. The result comes from a tensor reduction, so it is a 0-dim tensor
        rather than a Python `int`.
        """
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_keys, layer_values in zip(self.key_cache, self.value_cache):
            # In-place zeroing keeps the very same tensor objects alive
            layer_keys.zero_()
            layer_values.zero_()