from typing import Any

import torch
from transformers.cache_utils import Cache, _static_cache_update


class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        max_batch_size (`int`):
            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
            number of beams if you are running beam search.
        head_dim (`int`):
            The dimension of each attention head. Some models define a custom `head_dim` that differs from
            `config.hidden_size // config.num_attention_heads`.
        num_key_value_heads (`int`):
            The number of key/value heads for which cache space is allocated in each layer.
        num_hidden_layers (`int`):
            The number of hidden layers in the model, i.e. the number of per-layer caches to allocate.
        max_cache_len (`int`, *optional*):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. If you're using more than 1 computation device, you
            should pass the `layer_device_map` argument instead.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the cache tensors.
        layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
            Mapping between the layers and their devices. This is required when you are manually initializing the cache
            and the model is split between different GPUs. You can find out which layer is mapped to which device by
            checking the associated device map: `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(
        ...     max_batch_size=1,
        ...     head_dim=model.config.hidden_size // model.config.num_attention_heads,
        ...     num_key_value_heads=model.config.num_key_value_heads,
        ...     num_hidden_layers=model.config.num_hidden_layers,
        ...     max_cache_len=max_generated_length,
        ...     device=model.device,
        ...     dtype=model.dtype,
        ... )
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values  # access cache filled with key/values from generation
        StaticCache()
        ```
    """

    is_compileable = True

    def __init__(
        self,
        max_batch_size: int,
        head_dim: int,
        num_key_value_heads: int,
        num_hidden_layers: int,
        max_cache_len: int | None = None,
        device: torch.device | str | None = None,
        dtype: torch.dtype = torch.float32,
        layer_device_map: dict[int, str | torch.device | int] | None = None,
    ) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self._dtype = dtype
        self.num_key_value_heads = num_key_value_heads
        self.num_hidden_layers = num_hidden_layers

        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # Note: There will be a significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for idx in range(self.num_hidden_layers):
            if layer_device_map is not None:
                layer_device = layer_device_map[idx]
            else:
                layer_device = device
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
            # preventing compiled graph breaks when updating the cache.
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache.

        Return:
            A tuple containing the updated key and value states.
        """
        if cache_kwargs is None:
            cache_kwargs = {}

        key_states = key_states.to(self.key_cache[layer_idx].dtype)
        value_states = value_states.to(self.value_cache[layer_idx].dtype)
        return _static_cache_update(
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            key_states,
            value_states,
            cache_kwargs.get("cache_position"),
        )

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_cache_shape(self) -> int | None:
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
        for each layer.
        """
        kv_length = self.get_max_cache_shape()
        return kv_length, 0
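

# A minimal usage sketch (not part of the original file): it shows how `StaticCache.update`
# is typically driven with a `cache_position` tensor during iterative decoding. Shapes and
# names below are illustrative assumptions, and the write-at-position behavior is delegated
# to `_static_cache_update`, so the final assertion reflects assumed behavior.
def _static_cache_usage_example() -> None:
    cache = StaticCache(
        max_batch_size=1,
        head_dim=64,
        num_key_value_heads=4,
        num_hidden_layers=2,
        max_cache_len=16,
        device="cpu",
        dtype=torch.float32,
    )
    # Write 3 prompt tokens into slots [0, 1, 2] of layer 0. Indexing with a tensor
    # (rather than a Python int or slice) avoids an extra device copy, as noted in `update`.
    key_states = torch.randn(1, 4, 3, 64)
    value_states = torch.randn(1, 4, 3, 64)
    keys, values = cache.update(
        key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": torch.arange(3)}
    )
    # Decoding step: append a single new token at position 3 of the same layer.
    keys, values = cache.update(
        torch.randn(1, 4, 1, 64),
        torch.randn(1, 4, 1, 64),
        layer_idx=0,
        cache_kwargs={"cache_position": torch.tensor([3])},
    )
    # Four slots of the static buffer should now hold non-zero values.
    assert int(cache.get_seq_length(0)) == 4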


class Cache:
    """
    A cache used for storing hidden states produced by flash linear attention models.

    It stores the states of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
    """

    is_compileable = True

    def __init__(self, seen_tokens: int = 0) -> None:
        super().__init__()
        self.states: list[dict[str, Any]] = []
        self._seen_tokens = seen_tokens  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> dict[str, Any]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(
                f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
            )

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        recurrent_state: torch.Tensor | None = None,
        attn_state: tuple[torch.Tensor, torch.Tensor] | None = None,
        conv_state: tuple[torch.Tensor] | None = None,
        ffn_state: torch.Tensor | None = None,
        layer_idx: int = 0,
        offset: int | None = 1,
        cache_kwargs: dict | None = None,
    ):
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new feed-forward state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.

        Return:
            Dictionary of the updated state.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += offset

        if attn_state is not None:
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]
            window_size = cache_kwargs.get("window_size", None) if cache_kwargs is not None else None

        if len(self.states) <= layer_idx:
            if attn_state is not None:
                if window_size is not None and input_size > window_size:
                    attn_state = (
                        attn_state[0][..., -window_size:, :].contiguous(),
                        attn_state[1][..., -window_size:, :].contiguous(),
                    )
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state,
            )
            self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state["recurrent_state"] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state["attn_state"]
                if window_size is not None and key_state.shape[-2] == window_size:
                    # DO NOT allocate new memory if the cache is full
                    # roll the key/value states to the left by `input_size`
                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)
                    # replace the last `input_size` tokens with the new key/value states
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]
                    attn_state = (key_state, value_state)
                else:
                    attn_state = (
                        torch.cat([key_state, attn_state[0]], -2),
                        torch.cat([value_state, attn_state[1]], -2),
                    )
                state["attn_state"] = attn_state
            if conv_state is not None:
                state["conv_state"] = conv_state
            if ffn_state is not None:
                state["ffn_state"] = ffn_state

        return state

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> int | None:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> tuple:
        return tuple(self.states)

    @classmethod
    def from_legacy_cache(cls, past_key_values: tuple | None = None, seen_tokens: int = 0) -> "Cache":
        """Converts a cache in the legacy cache format into an equivalent `Cache`."""
        cache = cls(seen_tokens)
        if isinstance(past_key_values, (list, tuple)):
            for layer_idx in range(len(past_key_values)):
                cache.states.append(past_key_values[layer_idx])
        return cache
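

# A minimal usage sketch (not part of the original file): it shows how this flash-linear-attention
# style `Cache` keeps a sliding window of attention key/value states when `window_size` is passed
# through `cache_kwargs`. The shapes and window size below are illustrative assumptions.
def _cache_usage_example() -> None:
    cache = Cache()
    # First call for layer 0: six tokens arrive, but only the last four fit the window.
    keys = torch.randn(1, 4, 6, 64)
    values = torch.randn(1, 4, 6, 64)
    state = cache.update(
        attn_state=(keys, values),
        layer_idx=0,
        offset=keys.shape[-2],
        cache_kwargs={"window_size": 4},
    )
    assert state["attn_state"][0].shape[-2] == 4  # trimmed to the window on allocation
    # Later calls roll the existing buffer in place instead of growing it.
    state = cache.update(
        attn_state=(torch.randn(1, 4, 1, 64), torch.randn(1, 4, 1, 64)),
        layer_idx=0,
        offset=1,
        cache_kwargs={"window_size": 4},
    )
    assert state["attn_state"][0].shape[-2] == 4
    assert cache.get_seq_length() == 7  # `_seen_tokens` counts every processed token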