diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 13f310669200..cdb7762c7fde 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -82,22 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s ## Cache storage implementation -The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`]. +Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`. +Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`); the layer type mostly determines how the sequence length is handled and how the cache is updated. -In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`. -- `key_cache`: A list of tensors, one for each layer. -- `value_cache`: A list of tensors, one for each layer. +The simplest layer type is `DynamicLayer`, which grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token: -When new tokens are processed: - -1. For each layer, the new key and value states are concatenated with the existing cache. ```py -self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) -self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) +cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2) +cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2) ``` -2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. +Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, the oldest tokens are shifted out of the cache when a new token is added. The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
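For readers skimming this diff, the generation loop that the updated page refers to can be sketched roughly as follows. This is an illustrative sketch only, not part of the patch above, and it assumes a small causal LM checkpoint (`gpt2`) purely for demonstration; the key points are that the attention mask is extended by one position per step and the cache position advances by one.

```py
# Illustrative sketch (not part of the diff): a manual generation loop with DynamicCache.
# The checkpoint "gpt2" is an arbitrary small model chosen only for the example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The key-value cache", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

past_key_values = DynamicCache()
cache_position = torch.arange(input_ids.shape[1], dtype=torch.long)
generated = input_ids

for _ in range(10):
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )
    # Greedy pick of the next token from the last position's logits.
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)

    # Only the new token is fed on the next step; the cache already holds the past keys/values.
    input_ids = next_token
    # The attention mask covers past + current tokens, and the cache position advances by 1.
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
    cache_position = cache_position[-1:] + 1

print(tokenizer.decode(generated[0], skip_special_tokens=True))
```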
diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 1c17e99d5da3..c64ba2a3ca43 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -356,66 +356,93 @@ A [`Constraint`] can be used to force the generation to include specific tokens ## Caches -[[autodoc]] Cache +[[autodoc]] CacheLayerMixin - update + - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset + - reorder_cache -[[autodoc]] CacheConfig - - update +[[autodoc]] DynamicLayer + - update + - crop + - batch_repeat_interleave + - batch_select_indices -[[autodoc]] QuantizedCacheConfig - - validate +[[autodoc]] StaticLayer + - update -[[autodoc]] DynamicCache +[[autodoc]] SlidingWindowLayer + - update + +[[autodoc]] CacheProcessor + - pre_update + - post_update + +[[autodoc]] OffloadedCacheProcessor + - pre_update + +[[autodoc]] QuantizedCacheProcessor + - post_update + +[[autodoc]] QuantoQuantizedCacheProcessor + - post_update + +[[autodoc]] HQQQuantizedCacheProcessor + - post_update + +[[autodoc]] Cache - update - get_seq_length + - get_mask_sizes + - get_max_cache_shape + - reset - reorder_cache + - crop + - batch_repeat_interleave + - batch_select_indices + +[[autodoc]] DynamicCache - to_legacy_cache - from_legacy_cache [[autodoc]] QuantizedCache - - update - - get_seq_length [[autodoc]] QuantoQuantizedCache +[[autodoc]] QuantoQuantizedCacheProcessor + [[autodoc]] HQQQuantizedCache +[[autodoc]] HQQQuantizedCacheProcessor + [[autodoc]] OffloadedCache - - update - - prefetch_layer - - evict_previous_layer [[autodoc]] StaticCache - - update - - get_seq_length - - reset [[autodoc]] OffloadedStaticCache - - update - - get_seq_length - - reset [[autodoc]] HybridCache - - update - - get_seq_length - - reset + +[[autodoc]] HybridChunkedCache [[autodoc]] SlidingWindowCache - - update - - reset [[autodoc]] EncoderDecoderCache - - get_seq_length - to_legacy_cache - from_legacy_cache - - reset - - reorder_cache [[autodoc]] MambaCache - update_conv_state - update_ssm_state - reset +[[autodoc]] CacheConfig + +[[autodoc]] QuantizedCacheConfig + + ## Watermark Utils [[autodoc]] WatermarkingConfig diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md index c6c5f655582c..a1b6dd81ff16 100644 --- a/docs/source/en/kv_cache.md +++ b/docs/source/en/kv_cache.md @@ -134,7 +134,7 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t > [!WARNING] > Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency. -Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length. +Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`]. The quantization backend, along with any additional quantization-related parameters, should be passed as a dict. You should use the default values for these additional parameters unless you're running out-of-memory.
In that case, consider decreasing the residual length. @@ -143,7 +143,7 @@ For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value` ```py import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto") @@ -161,7 +161,7 @@ For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-valu ```py import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto") @@ -275,7 +275,6 @@ from transformers.cache_utils import ( StaticCache, SlidingWindowCache, QuantoQuantizedCache, - QuantizedCacheConfig, ) model_id = "meta-llama/Llama-2-7b-chat-hf" diff --git a/docs/source/ko/internal/generation_utils.md b/docs/source/ko/internal/generation_utils.md index e4841f0c626a..1a08a79368d3 100644 --- a/docs/source/ko/internal/generation_utils.md +++ b/docs/source/ko/internal/generation_utils.md @@ -345,12 +345,6 @@ generation_output[:2] [[autodoc]] Cache - update -[[autodoc]] CacheConfig - - update - -[[autodoc]] QuantizedCacheConfig - - validate - [[autodoc]] DynamicCache - update - get_seq_length diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3d1566580af3..84892590b1af 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -365,15 +365,28 @@ ] _import_structure["activations"] = [] _import_structure["cache_utils"] = [ + "CacheLayerMixin", + "DynamicLayer", + "StaticLayer", + "SlidingWindowLayer", + "ChunkedSlidingLayer", + "CacheProcessor", + "OffloadedCacheProcessor", + "QuantizedCacheProcessor", + "QuantoQuantizedCacheProcessor", + "HQQQuantizedCacheProcessor", "Cache", "CacheConfig", "DynamicCache", "EncoderDecoderCache", "HQQQuantizedCache", + "HQQQuantizedCacheProcessor", "HybridCache", + "HybridChunkedCache", "OffloadedCache", "OffloadedStaticCache", "QuantizedCache", + "QuantoQuantizedCacheProcessor", "QuantizedCacheConfig", "QuantoQuantizedCache", "SinkCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index eecf0c7c0e80..c8471b60e449 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,11 +1,12 @@ import copy +import functools import importlib.metadata +import inspect import json import os -import warnings from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch from packaging import version @@ -23,353 +24,703 @@ logger = logging.get_logger(__name__) -# Utility functions for static/sliding cache update logic -def _static_cache_update( - k_cache: torch.Tensor, - v_cache: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_position: Optional[torch.LongTensor], -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the static cache tensors in place. - - Args: - k_cache (`torch.Tensor`): The key cache tensor to update. 
- v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted. - If None, the entire cache is overwritten (prefill). +class CacheLayerMixin: + """Base, abstract class for a single layer's cache.""" - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place). - """ - if cache_position is None: - # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - k_cache.copy_(key_states) - v_cache.copy_(value_states) - else: - # Generation phase. Update specific positions. - # Use index_copy_ for in-place update (compile-friendly). - try: - k_cache.index_copy_(2, cache_position, key_states) - v_cache.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # Fallback for devices like MPS where index_copy_ might not be supported. - k_cache[:, :, cache_position] = key_states - v_cache[:, :, cache_position] = value_states - return k_cache, v_cache - - -def _sliding_cache_update( - k_cache: torch.Tensor, - v_cache: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_position: torch.LongTensor, - max_cache_len: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the sliding window cache tensors, returning the potentially modified tensors. + is_compileable = False - Args: - k_cache (`torch.Tensor`): The key cache tensor to update. - v_cache (`torch.Tensor`): The value cache tensor to update. - key_states (`torch.Tensor`): The new key states to add. - value_states (`torch.Tensor`): The new value states to add. - cache_position (`torch.LongTensor`): The position indices where the new states should be inserted. - max_cache_len (`int`): The maximum length of the sliding window cache. + def __init__(self): + self.keys, self.values = None, None - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update. - For prefill > window, these are the full input states. - Otherwise, they are the updated cache tensors. 
- """ - # Handle prefill phase when prompt length > sliding_window_size - if cache_position.shape[0] > max_cache_len: - new_k = key_states[:, :, -max_cache_len:, :] - new_v = value_states[:, :, -max_cache_len:, :] - k_cache.copy_(new_k) - v_cache.copy_(new_v) - return key_states, value_states + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Updates KV cache, returns updated keys/values of the layer.""" + raise NotImplementedError(f"Make sure to implement `update` in {self.__class__.__name__}.") - # Sliding window logic for generation phase or prefill < window - slicing = torch.arange(max_cache_len, device=value_states.device) - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - to_shift = current_seq_len > max_cache_len - indices = (slicing + to_shift.sum()) % max_cache_len + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_seq_length` in {self.__class__.__name__}.") - k_out_shifted = k_cache[:, :, indices] - v_out_shifted = v_cache[:, :, indices] + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length (i.e. max capacity) of this layer's cache.""" + raise NotImplementedError(f"Make sure to implement `get_max_cache_shape` in {self.__class__.__name__}.") - # Clamp cache_position to determine the *target index* within the shifted cache view - update_position = cache_position.clamp(min=0, max=max_cache_len - 1) + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Returns mask sizes for the layer.""" + raise NotImplementedError(f"Make sure to implement `get_mask_sizes` in {self.__class__.__name__}.") - try: - k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) - v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) - except NotImplementedError: - # Fallback for MPS: clone and modify the clone - k_out_updated = k_out_shifted.clone() - v_out_updated = v_out_shifted.clone() - k_out_updated[:, :, update_position] = key_states - v_out_updated[:, :, update_position] = value_states + def reset(self) -> None: + """Resets the cache values while preserving the objects""" + self.keys.zero_() + self.values.zero_() - k_cache.copy_(k_out_updated) - v_cache.copy_(v_out_updated) - return k_out_updated, v_out_updated + def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: + """Reorders this layer's cache for beam search.""" + if self.keys.numel(): + device = self.keys.device + self.keys = self.keys.index_select(0, beam_idx.to(device)) + if self.values.numel(): + device = self.values.device + self.values = self.values.index_select(0, beam_idx.to(device)) -class Cache: +class DynamicLayer(CacheLayerMixin): """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. + A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. + It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. 
""" - is_compileable = False + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer": + """ + Build a `DynamicLayer` instance from pre-existing key/value tensors. - def __init__(self): - super().__init__() + Args: + keys (`torch.Tensor`): + Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + values (`torch.Tensor`): + Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + + Returns: + `DynamicLayer`: The newly constructed layer whose internal cache directly references + the supplied tensors. + """ + layer = cls() + layer.keys = keys + layer.values = values + return layer def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. Return: A tuple containing the updated key and value states. """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + return self.keys, self.values - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" + if self.keys is None or self.keys.numel() == 0: + return 0 + return self.keys.shape[-2] - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object. 
DynamicLayer does not have a maximum length.""" + return -1 - def reorder_cache(self, beam_idx: torch.LongTensor): + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + def crop(self, max_length: int) -> None: """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. + Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. """ + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[..., :max_length, :] + self.values = self.values[..., :max_length, :] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Repeat the cache `repeats` times in the batch dimension.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Only keep the `indices` in the batch dimension of the cache.""" + if self.keys is not None and self.keys.numel(): + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask""" + kv_offset = 0 query_length = cache_position.shape[0] past_seen_tokens = self.get_seq_length() kv_length = query_length + past_seen_tokens - return kv_length, 0 + return kv_length, kv_offset -@dataclass -class CacheConfig: +class StaticLayer(CacheLayerMixin): """ - Base class for cache configs + A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`. + It allocates its full backing tensors up-front and mutates them in-place. Built for `torch.compile` support. + + See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. """ - cache_implementation: None + is_compileable = True + is_sliding = False - @classmethod - def from_dict(cls, config_dict, **kwargs): + def __init__( + self, + max_cache_len: int, + batch_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + device: str = "cpu", + sliding_window: Optional[int] = None, + ): """ - Constructs a CacheConfig instance from a dictionary of parameters. 
Args: - config_dict (dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. + batch_size (`int`): + Maximum batch size the cache is pre-allocated for. + num_heads (`int`): + Number of attention heads. + head_dim (`int`): + Per-head hidden dimension. + dtype (`torch.dtype`, defaults to `torch.float32`): + Data type of the cache tensors. + device (`str` or `torch.device`, defaults to `"cpu"`): + Device on which the cache tensors will be materialised. + + Notes: + Static layers allocate their full backing tensors up-front and mutate them + in-place. See the documentation of `Cache` for shared helper methods that + operate uniformly across all layer types. """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config + self.max_cache_len = max_cache_len + self.max_batch_size = batch_size + self.num_heads = num_heads + self.head_dim = head_dim + self.dtype = dtype + self.device = device - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): + self.keys = torch.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, + ) + self.values = torch.zeros( + (batch_size, num_heads, self.max_cache_len, head_dim), + dtype=dtype, + device=device, + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) + + def get_max_cache_shape(self) -> int: + """Return the maximum cache shape of the cache""" + return self.max_cache_len + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Save this instance to a JSON file. + Update the static cache tensors in place. Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) - writer.write(json_string) + if cache_position is None: + # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. 
+            self.keys.copy_(key_states) +            self.values.copy_(value_states) +        else: +            # Generation phase. Update specific positions. +            # Use index_copy_ for in-place update (compile-friendly). +            try: +                self.keys.index_copy_(2, cache_position, key_states) +                self.values.index_copy_(2, cache_position, value_states) +            except NotImplementedError: +                # Fallback for devices like MPS where index_copy_ might not be supported. +                self.keys[:, :, cache_position] = key_states +                self.values[:, :, cache_position] = value_states +        return self.keys, self.values + +    def get_seq_length(self, cache_position=None) -> int: +        """Returns the sequence length of the cached states.""" +        if cache_position is not None: +            return int(cache_position[-1] + 1) +        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's +        # limit the check to the first batch member and head dimension. +        seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0 +        return seq_length -    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict -    def to_dict(self) -> dict[str, Any]: -        """ -        Serializes this instance to a Python dictionary. Returns: -            `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. -        """ -        return copy.deepcopy(self.__dict__) +    def reorder_cache(self, beam_idx: torch.LongTensor) -> None: +        """Reorders the cache for beam search, given the selected beam indices.""" +        dev = self.keys.device +        beam_idx_dev = beam_idx.to(dev) +        self.keys = self.keys.index_select(0, beam_idx_dev) +        self.values = self.values.index_select(0, beam_idx_dev) + +    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: +        """Return the length and offset of the cache, used to generate the attention mask""" +        kv_offset = 0 +        kv_length = self.max_cache_len +        return kv_length, kv_offset -    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ -    def __iter__(self): -        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" -        for attr, value in copy.deepcopy(self.__dict__).items(): -            yield attr, value -    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ -    def __repr__(self): -        return f"{self.__class__.__name__} {self.to_json_string()}" +class SlidingWindowLayer(StaticLayer): +    """ +    A static cache layer that implements sliding window attention caching. -    def to_json_string(self): +    See `CacheLayerMixin` for details on common methods that are implemented by all cache layers. +    """ + +    def __init__(self, sliding_window, *args, **kwargs): +        """ -        Serializes this instance to a JSON formatted string. -        Returns: -            str: JSON formatted string representing the configuration instance. +        Args: +            sliding_window (`int`): +                Effective window size: number of tokens that are kept on each update call. +        """ -        return json.dumps(self.__dict__, indent=2) + "\n" +        kwargs.pop("max_cache_len", None) +        super().__init__(*args, max_cache_len=sliding_window, **kwargs) -    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update -    def update(self, **kwargs): +    def update( +        self, +        key_states: torch.Tensor, +        value_states: torch.Tensor, +        cache_kwargs: Optional[dict[str, Any]] = None, +    ) -> tuple[torch.Tensor, torch.Tensor]: +        """ -        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, -        returning all the unused kwargs.
+ Update the sliding window cache tensors in place. Args: - kwargs (`dict[str, Any]`): - Dictionary of attributes to tentatively update this class. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Returns: - `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: + raise ValueError("`cache_position` must be provided for SlidingWindowLayer.") - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs + key_states = key_states.to(self.keys.dtype) + value_states = value_states.to(self.values.dtype) + # Handle prefill phase when prompt length > sliding_window_size. + # Note that we store cropped key/value states in the cache but return the full key/value states. + if cache_position.shape[0] > self.max_cache_len: + new_k = key_states[:, :, -self.max_cache_len :, :] + new_v = value_states[:, :, -self.max_cache_len :, :] + self.keys.copy_(new_k) + self.values.copy_(new_v) + return key_states, value_states -@dataclass -class QuantizedCacheConfig(CacheConfig): + # Sliding window logic for generation phase or prefill < window + slicing = torch.arange(self.max_cache_len, device=value_states.device) + current_seq_len = cache_position[-1] + 1 # Use last position to determine current length + to_shift = current_seq_len > self.max_cache_len + indices = (slicing + to_shift.sum()) % self.max_cache_len + + k_out_shifted = self.keys[:, :, indices] + v_out_shifted = self.values[:, :, indices] + + # Clamp cache_position to determine the *target index* within the shifted cache view + update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) + + try: + k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) + v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) + except NotImplementedError: + # Fallback for MPS: clone and modify the clone + k_out_updated = k_out_shifted.clone() + v_out_updated = v_out_shifted.clone() + k_out_updated[:, :, update_position] = key_states + v_out_updated[:, :, update_position] = value_states + + self.keys.copy_(k_out_updated) + self.values.copy_(v_out_updated) + return self.keys, self.values + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the attention mask""" + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + + kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) + # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns + kv_length = max(query_length, self.max_cache_len) + return kv_length, kv_offset + + +class ChunkedSlidingLayer(SlidingWindowLayer): """ - Configuration class for quantized cache settings. + An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4. 
- Attributes: - backend (`str`, *optional*, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`Optional[int]`, *optional*, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. + See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers. """ - def __init__( + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cumulative_length = 0 + + def update( self, - backend: str = "quanto", - nbits: Optional[int] = 4, - axis_key: Optional[int] = 0, - axis_value: Optional[int] = 0, - q_group_size: Optional[int] = 64, - residual_length: Optional[int] = 128, - compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", - ): - self.backend = backend - self.nbits = nbits - self.axis_key = axis_key - self.axis_value = axis_value - self.q_group_size = q_group_size - self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None + if cache_position is None: + raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") - def validate(self): - """Validates if the arguments passed are correct""" + cumulative_length = self.cumulative_length + self.cumulative_length += key_states.shape[-2] + is_full = cumulative_length >= self.max_cache_len - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( + if is_full: + full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + # Fast decoding path -> here as the effective size is still sliding window, it is extremely important + # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address + # in memory (the values are the same as the full states, but not the address!!) + if key_states.shape[-2] == 1: + self.keys.copy_(full_key_states) + self.values.copy_(full_value_states) + return self.keys, self.values + elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + if cumulative_length == 0: + full_key_states = key_states + full_value_states = value_states + else: + full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states + return self.keys, self.values + + self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + return full_key_states, full_value_states + + def reset(self) -> None: + super().reset() + self.cumulative_length = 0 + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + query_length = cache_position.shape[0] + first_cache_position = cache_position[0] + sliding_window = self.max_cache_len + + kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0) + # This is the true general case for any Cache using local attention (sliding or chunked) + if first_cache_position >= sliding_window: + # Here the Cache is already full + kv_length = sliding_window + query_length - 1 + elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window: + # Here the Cache becomes full with the new input + kv_length = first_cache_position + query_length + else: + # Here the Cache is still smaller than the local size, but we return the local size as it's static + kv_length = sliding_window + return kv_length, kv_offset + + +class CacheProcessor: + """ + Base class for cache processors. It defines a pre-update and post-update methods that are called before and after the cache update. + This class should be subclassed. + """ + + def __init__(self, cache: "Cache", **kwargs) -> None: + """ + Initialize the processor and perform compatibility checks with the cache. + + Args: + cache (`Cache`): The cache instance this processor will be applied to. + **kwargs: Additional arguments that may be needed for initialization. 
+ """ + raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function called before the cache update. Can modify the key/value states. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The new key states to cache. + value_states (`torch.Tensor`): The new value states to cache. + layer_idx (`int`): The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + The modified key and value states. + """ + return key_states, value_states + + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function called after the cache update. Can process the cached data. + + Args: + cache (`Cache`): The cache instance. + key_states (`torch.Tensor`): The key states that were cached. + value_states (`torch.Tensor`): The value states that were cached. + layer_idx (`int`): The index of the layer that was updated. + cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. + + Returns: + The final key and value states to return to the model. + """ + return key_tensors, value_tensors + + +class OffloadedCacheProcessor(CacheProcessor): + """ + A cache processor that offloads cache tensors to conserve accelerator memory. + + This processor manages moving cache tensors between accelerator and CPU memory, + using asynchronous prefetching to minimize performance impact. Works with both + dynamic and static layers. + """ + + def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "cpu", **kwargs): + """Initialize the offload processor and check device compatibility.""" + self.offload_device = torch.device(offload_device) + self.original_device = [] + self.prefetch_stream = None + self.beam_idx = None + + if not ( + torch.cuda.is_available() + or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) + ): + raise RuntimeError( + "OffloadedCacheProcessor can only be used with a GPU" + + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") + ) + + self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) + if self.is_static: + for i, layer in enumerate(cache.layers): + device = cache.layer_init_args["device"] if i == 0 else self.offload_device + layer.keys = layer.keys.to(device) + layer.values = layer.values.to(device) + self.original_device.append(cache.layer_init_args["device"]) + if len(cache) != cache.num_hidden_layers: + raise ValueError("If static layers are used, all cache layers must be initialized") + + self.prefetch_stream = ( + torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() + ) + + def pre_update( + self, + cache: "Cache", + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Handles prefetching and eviction before cache update.""" + # Update the cache + if len(cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") + elif len(cache) == layer_idx: + self.original_device.append(key_states.device) + self._evict_previous_layer(cache, layer_idx) + else: + # Wait for the previous layer to be evicted (on default stream) + if is_torch_greater_or_equal("2.7", accept_dev=True): + torch.accelerator.current_stream().synchronize() + else: + torch.cuda.current_stream().synchronize() + self._evict_previous_layer(cache, layer_idx) + self._ensure_layer_on_device(cache, layer_idx) + + # Prefetch the next layer + self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) + return key_states, value_states + + def _prefetch_layer(self, cache: "Cache", layer_idx: int): + """Starts prefetching the next layer cache.""" + if layer_idx < len(cache): + with ( + self.prefetch_stream + if is_torch_greater_or_equal("2.7", accept_dev=True) + else torch.cuda.stream(self.prefetch_stream) + ): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) + + def _evict_previous_layer(self, cache: "Cache", layer_idx: int): + """Moves the previous layer cache to the CPU.""" + if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + prev_layer_idx = (layer_idx - 1) % len(cache) + cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( + self.offload_device, non_blocking=True + ) + cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to( + self.offload_device, non_blocking=True + ) + + def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): + """Ensures the current layer is on the original device.""" + if layer_idx < len(cache): + # Wait for the previous prefetch to be done + self.prefetch_stream.synchronize() + + # Handle delayed beam search operations + if self.beam_idx is not None: + self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) + cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) + cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) + + +class QuantizedCacheProcessor(CacheProcessor): + """ + A cache processor that applies quantization to cache tensors to reduce memory usage. + + This processor quantizes cache tensors after they are stored, maintaining a residual + length in original precision and quantizing older tokens. + """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", + ): + """ + Parameters: + backend (`str`, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`int`, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + axis_value (`int`, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. 
+ q_group_size (`int`, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`int`, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`torch.dtype`, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device + self._quantized_keys: list[torch.Tensor] = [] + self._quantized_values: list[torch.Tensor] = [] + + self.validate() + self.erased_length = 0 + + # Only compatible with DynamicCache + if not isinstance(cache.layers[0], DynamicLayer): + raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") + + def validate(self): + """Validates if the arguments passed are correct""" + + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: + raise ValueError( + incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( incorrect_arg_msg.format( key="residual_length", correct_value="a positive integer", @@ -395,45 +746,523 @@ def validate(self): ), ) + def post_update( + self, + cache: "Cache", + key_tensors: torch.Tensor, + value_tensors: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply quantization after cache update.""" -@dataclass -class StaticCacheConfig(CacheConfig): + if len(cache) < layer_idx: + raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") + + # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer + # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) + # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. 
+ if self._is_quantized_length_zero(layer_idx): + self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key)) + self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value)) + + # Clear the residual cache + self.erased_length = key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.layers[layer_idx].values = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) + # On prefill, we return the original prompt + keys_to_return, values_to_return = key_tensors, value_tensors + + else: + # Prepend the previously quantized cache + dequant_key = self._dequantize(self._quantized_keys[layer_idx]) + dequant_value = self._dequantize(self._quantized_values[layer_idx]) + keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) + values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) + if key_tensors.shape[-2] >= self.residual_length: + # Quantize and store + self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value) + + # Clear the residual cache + self.erased_length += key_tensors.shape[-2] + cache.layers[layer_idx].keys = torch.zeros( + 0, + dtype=key_tensors.dtype, + device=key_tensors.device, + ) + cache.layers[layer_idx].values = torch.zeros( + 0, + dtype=value_tensors.dtype, + device=value_tensors.device, + ) + + return keys_to_return, values_to_return + + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _quantize method") + + def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor - to be implemented by specific quantization backends.""" + raise NotImplementedError("Quantization backend must implement _dequantize method") + + def _is_quantized_length_zero(self, layer_idx: int) -> bool: + """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" + return layer_idx >= len(self._quantized_keys) + + +class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): """ - Configuration class for static cache settings. + Quantized cache processor that uses `quanto` as a backend to perform quantization. + Current implementation supports `int2` and `int4` dtypes only. 
""" - cache_implementation = "static" + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", + ) -> None: + """Initialize the quanto quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device + ) - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - self.batch_size = batch_size - self.max_cache_len = max_cache_len - self.device = device + if backend != "quanto": + raise ValueError(f"QuantoQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") - def validate(self): - """Validates if the arguments passed are correct""" + if is_optimum_quanto_available(): + optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) + if optimum_quanto_version <= version.parse("0.2.5"): + raise ImportError( + f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." + ) + from optimum.quanto import MaxOptimizer, qint2, qint4 - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" + if self.nbits not in [2, 4]: + raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") + + if self.axis_key not in [0, -1]: + raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") + + if self.axis_value not in [0, -1]: + raise ValueError( + f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" + ) + + self.qtype = qint4 if self.nbits == 4 else qint2 + self.optimizer = MaxOptimizer() + + def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Quantize tensor using quanto backend.""" + if is_optimum_quanto_available(): + from optimum.quanto import quantize_weight + + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor + + def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: + """Dequantize tensor using quanto backend.""" + return qtensor.dequantize() + + +class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): + """ + Quantized cache processor that uses `HQQ` as a backend to perform quantization. + Current implementation supports `int2`, `int4`, `int8` dtypes. 
+ """ + + def __init__( + self, + cache: "Cache", + backend: str = "quanto", + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + compute_dtype: torch.dtype = torch.float16, + device: str = "cpu", + ) -> None: + """Initialize the HQQ quantization processor.""" + super().__init__( + cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device ) - if self.batch_size <= 0: + if backend != "quanto": + raise ValueError(f"HQQQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") + + if self.nbits not in [1, 2, 3, 4, 8]: raise ValueError( - incorrect_arg_msg.format( - key="batch_size", - correct_value="> 0", - found_value=self.batch_size, - ), + f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" + ) + + if self.axis_key not in [0, 1]: + raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") + + if self.axis_value not in [0, 1]: + raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") + + self.quantizer = HQQQuantizer + + def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: + """Quantize tensor using HQQ backend.""" + qtensor, meta = self.quantizer.quantize( + tensor, + axis=axis, + device=self.device, + compute_dtype=self.compute_dtype, + nbits=self.nbits, + group_size=self.q_group_size, + ) + meta["compute_dtype"] = self.compute_dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + meta["scale"] = meta["scale"].to(qtensor.device) + meta["zero"] = meta["zero"].to(qtensor.device) + return qtensor, meta + + def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: + """Dequantize tensor using HQQ backend.""" + quant_tensor, meta = qtensor_and_meta + tensor = self.quantizer.dequantize(quant_tensor, meta) + return tensor + + +def apply_processors( + fn: Callable[..., tuple[torch.Tensor, torch.Tensor]], +) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: + @functools.wraps(fn) + def _wrapped_update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Wrapper around the update method to apply cache processors. + """ + if self.cache_processor is not None: + key_states, value_states = self.cache_processor.pre_update( + self, key_states, value_states, layer_idx, cache_kwargs + ) + + key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) + + if self.cache_processor is not None: + key_tensors, value_tensors = self.cache_processor.post_update( + self, key_tensors, value_tensors, layer_idx, cache_kwargs + ) + + return key_tensors, value_tensors + + return _wrapped_update + + +class KeyValuesWrapper: + """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache. + This allows for BC access and writing, e.g., cache.key_cache[idx] = ... + Deprecated in favor of Cache.layers[idx].keys/values. 
TODO: remove in v4.56.0""" + + def __init__(self, layers, cache_type="keys"): + self.layers = layers + self.cache_type = cache_type + + def __getitem__(self, idx): + if isinstance(idx, slice): + return [getattr(layer, self.cache_type) for layer in self.layers[idx]] + return getattr(self.layers[idx], self.cache_type) + + def __setitem__(self, idx, value): + if isinstance(idx, slice): + for layer, val in zip(self.layers[idx], value): + setattr(layer, self.cache_type, val) + else: + setattr(self.layers[idx], self.cache_type, value) + + def __len__(self): + return len(self.layers) + + def __iter__(self): + for layer in self.layers: + yield getattr(layer, self.cache_type) + + def __bool__(self): + return bool(self.layers) + + +class Cache: + """ + Base container for per-layer key/value caches. + + A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer. + Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache` + simply pre-select which `CacheLayerMixin` class to use and may attach a + `CacheProcessor` (off-loading, quantization). + + Example + ------- + ```python + from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache + + model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + inputs = tok("Hello", return_tensors="pt") + + cache = DynamicCache() + outputs = model(**inputs, past_key_values=cache, use_cache=True) + ``` + + Parameters: + config (`PretrainedConfig`, *optional*): + Model configuration used to infer number of layers, head sizes, default + device/dtype, etc. + cache_processor (`CacheProcessor` or `str`, *optional*): + Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") + or a CacheProcessor class. + layer_classes (`list[type[CacheLayerMixin]]`, *optional*): + List of `CacheLayerMixin` classes to instantiate for the cache. When shorter than the + required number of layers the list is cycled. Default is [DynamicLayer]. + max_batch_size (`int`, *optional*): Maximum batch size for static caches. + max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are + clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. + device (`torch.device`, *optional*): Device for cache tensors. + dtype (`torch.dtype`, *optional*): Data type for cache tensors. + layer_device_map (`dict[int, Union[str, torch.device]]`, *optional*): Per-layer device mapping. + tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads. + + Additional keyword arguments are forwarded to the chosen layers constructor(s) and CacheProcessors. See the + documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details. 
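A minimal, model-free sketch of how a cache fills up through the per-layer `update` API described above; shapes follow the documented `[batch_size, num_heads, seq_len, head_dim]` convention and the numbers are arbitrary:

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()

# A "prefill" of four tokens followed by two single-token decode steps, all for layer 0.
for step_len in (4, 1, 1):
    key_states = torch.randn(1, 2, step_len, 8)      # [batch_size, num_heads, seq_len, head_dim]
    value_states = torch.randn(1, 2, step_len, 8)
    keys, values = cache.update(key_states, value_states, layer_idx=0)

print(cache.get_seq_length(0))  # 6
print(keys.shape)               # torch.Size([1, 2, 6, 8])
```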
+ """ + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + cache_processor: Optional[Union[str, type["CacheProcessor"]]] = None, + layer_classes: Optional[list[type["CacheLayerMixin"]]] = None, + max_batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map: Optional[dict[int, torch.device]] = None, + tp_size: Optional[int] = None, + **kwargs, + ): + self.layers: list["CacheLayerMixin"] = [] + processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor + + if layer_classes is None: + layer_classes = [DynamicLayer] + + self.layer_classes = layer_classes + kwargs.update( + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + layer_device_map=layer_device_map, + tp_size=tp_size, + ) + processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) + self.layer_init_args = parse_layer_args_from_model_config(config, **kwargs) + self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) + + self.append_new_layers(self.num_hidden_layers - 1) + self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].keys, self.layers[layer_idx].values + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" ) - if self.max_cache_len <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="max_cache_len", - correct_value="> 0", - found_value=self.max_cache_len, - ), - ) + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ + if getattr(self, "layers", None) is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + # Empty dynamic caches initialize an empty layer to be ready for first update + dynamic_empty = ( + getattr(self, "layers", None) is not None + and len(self.layers) == 1 + and isinstance(self.layers[0], DynamicLayer) + and self.layers[0].keys is None + ) + return len(self.layers) if not dynamic_empty else 0 + + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" + + def append_new_layers(self, layer_idx: int) -> None: + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used for preallocation in static caches and on the fly in dynamic caches. + + Args: + layer_idx (`int`): + The index of the layer to append. 
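The cycling of `layer_classes` is easiest to see with the index arithmetic pulled out on its own; plain strings stand in for the real `CacheLayerMixin` classes in this illustrative sketch:

```python
# Stand-ins for e.g. [SlidingWindowLayer, StaticLayer]; when the list is shorter than the number of
# model layers it is cycled, mirroring `self.layer_classes[len(self.layers) % len(self.layer_classes)]`.
layer_classes = ["sliding", "static"]

layers = []
for _ in range(5):  # pretend the model has 5 decoder layers
    layers.append(layer_classes[len(layers) % len(layer_classes)])

print(layers)  # ['sliding', 'static', 'sliding', 'static', 'sliding']
```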
+ """ + while len(self.layers) <= layer_idx: + args = self.layer_init_args.copy() + if self.layer_init_args.get("layer_device_map", None) is not None: + args["device"] = args.pop("layer_device_map")[layer_idx] + new_layer = self.layer_classes[len(self.layers) % len(self.layer_classes)](**args) + self.layers.append(new_layer) + + @apply_processors + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position) + return self.layers[layer_idx].get_seq_length(cache_position) + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) + return kv_length, kv_offset + + @property + def key_cache(self) -> KeyValuesWrapper: + """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" + logger.warning_once( + "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." + ) + return KeyValuesWrapper(self.layers, "keys") + + @property + def value_cache(self) -> KeyValuesWrapper: + """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" + logger.warning_once( + "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." + ) + return KeyValuesWrapper(self.layers, "values") + + ### Wrappers for layer operations and properties ### + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + """Returns maximum sequence length of the cache object. 
Dynamic caches do not have a maximum length.""" + return self.layers[layer_idx].get_max_cache_shape() + + def reset(self): + """Recursively reset all layers tensors""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reset() + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder the cache for beam search""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].reorder_cache(beam_idx) + + def crop(self, max_length: int): + """Crop the cache to the given length""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].crop(max_length) + + def batch_repeat_interleave(self, repeats: int): + """Repeat and interleave the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Select indices from the cache""" + for layer_idx in range(len(self.layers)): + self.layers[layer_idx].batch_select_indices(indices) + + @property + def max_batch_size(self) -> int: + """Return the maximum batch size of the cache""" + values = [layer.max_batch_size for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max batch size is not consistent across layers: {values}") + return values[0] + + @property + def max_cache_len(self) -> int: + """Return the maximum cache length of the cache""" + values = [layer.max_cache_len for layer in self.layers] + if len(set(values)) > 1: + raise ValueError(f"Max cache length is not consistent across layers: {values}") + return values[0] + + @property + def is_compileable(self) -> bool: + """Return whether the cache is compileable""" + return all(layer.is_compileable for layer in self.layers) + + @property + def is_sliding(self) -> list[bool]: + """Return whether the layers of the cache are sliding window""" + return [getattr(layer, "is_sliding", False) for layer in self.layers] class DynamicCache(Cache): @@ -443,6 +1272,8 @@ class DynamicCache(Cache): It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. + See `Cache` for details on common methods that are implemented by all cache classes. + Example: ```python @@ -461,119 +1292,35 @@ class DynamicCache(Cache): ``` """ - def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: - super().__init__() - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 + # Specialized constructor for DDP cache data, needed for BC + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212 # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the # iterable contains the key and value states for a layer gathered across replicas by torch.distributed # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break + # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break # compatibility. The name of the argument doesn't matter. 
- if _distributed_cache_data is not None: - for key_states, value_states in _distributed_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) + super().__init__(*args, **kwargs) - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. """ - # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + for layer in self.layers: + legacy_cache += ((layer.keys, layer.values),) return legacy_cache @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" + def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): @@ -581,117 +1328,53 @@ def from_legacy_cache( cache.update(key_states, value_states, layer_idx) return cache - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. 
This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] - value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] - if key_cache != []: - layer_keys = torch.cat(key_cache, dim=0) - layer_values = torch.cat(value_cache, dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - # Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): - """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - - # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) +if is_torch_greater_or_equal("2.3"): -def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, -): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) - return cache + def _get_cache_dict(cache: DynamicCache): + if any(not isinstance(layer, DynamicLayer) for layer in cache.layers): + raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") + if not is_torch_greater_or_equal_than_2_6: + logger.warning_once( + "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." 
+ ) -def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), - } - return torch.fx._pytree._dict_flatten_spec(dictionary, spec) + return { + "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None], + "value_cache": [layer.values for layer in cache.layers if layer.values is not None], + } + def _unflatten_dynamic_cache( + values, + context: torch.utils._pytree.Context, + ): + dictionary = torch.utils._pytree._dict_unflatten(values, context) + cache = DynamicCache() + # Reconstruct layers from keys and values lists + key_list = dictionary.get("key_cache", []) + value_list = dictionary.get("value_cache", []) + for idx in range(max(len(key_list), len(value_list))): + key = key_list[idx] if idx < len(key_list) else None + value = value_list[idx] if idx < len(value_list) else None + cache.update(key, value, idx) + return cache -if is_torch_greater_or_equal("2.3"): torch.utils._pytree.register_pytree_node( DynamicCache, - _flatten_dynamic_cache, + lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)), _unflatten_dynamic_cache, serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, + flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys( + _get_cache_dict(dynamic_cache) + ), ) # TODO (tmanlaibaatar) This won't be needed in torch 2.7. - torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) + torch.fx._pytree.register_pytree_flatten_spec( + DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec) + ) class OffloadedCache(DynamicCache): @@ -707,121 +1390,149 @@ class OffloadedCache(DynamicCache): ensure the eviction is scheduled after all computations on that cache are finished. """ - def __init__(self) -> None: - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCache can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) + def __init__(self, config: Optional[PretrainedConfig] = None) -> None: + # Create the underlying cache with offload processor + super().__init__(cache_processor=OffloadedCacheProcessor, config=config) - super().__init__() - self.original_device = [] - self.prefetch_stream = None - self.prefetch_stream = ( - torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() - ) - self.beam_idx = None # used to delay beam search operations - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. 
- def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) + See `Cache` for details on common methods that are implemented by all cache classes. - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + Example: - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[StaticLayer], *args, **kwargs) + + +class HybridCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer (originally implemented for Gemma2). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of those layer types. + + See `Cache` for details on common methods that are implemented by all cache classes. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + def __init__(self, config: PretrainedConfig, *args, **kwargs): + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: + layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) + layer_classes = [StaticLayer] + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + + +class SlidingWindowCache(Cache): + """ + Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. 
+ Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, + if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), + we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. + + The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: + + indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window + tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) + + We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + + See `Cache` for details on common methods that are implemented by all cache classes. + + Example: - return self.key_cache[layer_idx], self.value_cache[layer_idx] + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() + ``` + """ - to_legacy_cache = None + def __init__(self, *args, **kwargs): + super().__init__(layer_classes=[SlidingWindowLayer], *args, **kwargs) class QuantizedCache(DynamicCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. 
+ + See `Cache` for details on common methods that are implemented by all cache classes. + """ + + def __init__(self, backend, **kwargs) -> None: + if backend == "quanto": + processor = QuantoQuantizedCacheProcessor + elif backend == "hqq": + processor = HQQQuantizedCacheProcessor + else: + raise ValueError(f"Unknown quantization backend `{backend}`") + + super().__init__(cache_processor=processor, **kwargs) + + +class QuantoQuantizedCache(QuantizedCache): """ A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. @@ -833,94 +1544,10 @@ class QuantizedCache(DynamicCache): It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and Value in original precision states as a list of tensors, one for each layer. The size of each tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() - - # Used only for QuantCache where the seq-length can't be inferred easily from cache contents - self._seen_tokens = 0 - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - - self.nbits = cache_config.nbits - self.residual_length = cache_config.residual_length - self.q_group_size = cache_config.q_group_size - self.axis_key = cache_config.axis_key - self.axis_value = cache_config.axis_value - self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - if len(self.key_cache) < layer_idx: - raise ValueError("QuantizedCache does not support model usage where layers are skipped. 
Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) - self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) - self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - keys_to_return, values_to_return = key_states, value_states - else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.axis_value - ) - self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def _quantize(self, tensor, axis): - """Quantizes a key/value using a defined quantization method.""" - raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") - - def _dequantize(self, q_tensor): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") - - -class QuantoQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. + Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + See `Cache` for details on common methods that are implemented by all cache classes. 
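The residual-length bookkeeping described above (and visible in the removed `update` on the left) can be sketched without any quantization library; the `quantize`/`dequantize` helpers below are stand-ins for the quanto and HQQ backends, not their APIs:

```python
import torch

residual_length = 4  # stands in for the `residual_length` argument

def quantize(t):
    # Stand-in for a real backend: any lossy-but-invertible transform works for the sketch.
    return t.to(torch.float16)

def dequantize(q):
    return q.to(torch.float32)

quantized = None                      # long-term, quantized storage
residual = torch.zeros(1, 1, 0, 8)    # short-term, full-precision buffer

for _ in range(10):                   # ten decode steps of one token each
    new_kv = torch.randn(1, 1, 1, 8)
    parts = [residual, new_kv] if quantized is None else [dequantize(quantized), residual, new_kv]
    visible = torch.cat(parts, dim=-2)        # what attention sees this step
    if residual.shape[-2] + 1 >= residual_length:
        quantized = quantize(visible)         # fold everything into the quantized store
        residual = torch.zeros(1, 1, 0, 8)
    else:
        residual = torch.cat([residual, new_kv], dim=-2)

print(visible.shape[-2])  # 10: all tokens are still visible, most of them stored quantized
```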
Example: @@ -942,51 +1569,26 @@ class QuantoQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." - ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - - if self.nbits not in [2, 4]: - raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") - - if self.axis_key not in [0, -1]: - raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") - - if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) - - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization + def __init__(self, **kwargs) -> None: + Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) - def _quantize(self, tensor, axis): - # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor +class HQQQuantizedCache(QuantizedCache): + """ + A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - def _dequantize(self, qtensor): - return qtensor.dequantize() + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the + original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The + quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. + It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and + Value in original precision states as a list of tensors, one for each layer. The size of each tensor + is `[batch_size, num_heads, seq_len - residual_length, head_dim]` -class HQQQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. + Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + See `Cache` for details on common methods that are implemented by all cache classes. 
Example: @@ -1008,355 +1610,103 @@ class HQQQuantizedCache(QuantizedCache): ``` """ - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) - - if self.axis_key not in [0, 1]: - raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") - - if self.axis_value not in [0, 1]: - raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") - - self.quantizer = HQQQuantizer - - def _quantize(self, tensor, axis): - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, - ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype - meta["scale"] = meta["scale"].to(qtensor.device) - meta["zero"] = meta["zero"].to(qtensor.device) - return qtensor, meta - - def _dequantize(self, qtensor): - quant_tensor, meta = qtensor - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor - - -class SinkCache(Cache): - """ - Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. - See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for - general `custom_generate`usage. - """ - - # TODO (joao, manuel): Remove this class in v4.59.0 - def __init__(self, **kwargs) -> None: - raise NotImplementedError( - "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " - "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." - ) + def __init__(self, backend="HQQ", **kwargs) -> None: + assert backend == "HQQ" + Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) -class StaticCache(Cache): +class OffloadedStaticCache(StaticCache): """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + A drop-in replacement for StaticCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. 
You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - tp_size (`Optional[int]`, *optional*): - The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache - if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the - number of key/value heads will not be adjusted. + This cache maintains the compilation-friendly properties of StaticCache while enabling + much longer sequences by offloading inactive layers to CPU memory. + See `Cache` for details on common methods that are implemented by all cache classes. Example: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> # Prepare a cache class with offloading >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = OffloadedStaticCache( + ... config=model.config, + ... max_batch_size=1, + ... max_cache_len=max_generated_length, + ... device=model.device, + ... dtype=model.dtype + ... ) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - StaticCache() + >>> outputs.past_key_values # access cache with offloaded layers + OffloadedStaticCache() ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - tp_size: Optional[int] = None, - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - if tp_size is not None and tp_size > 1: - if self.num_key_value_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. 
- self.num_key_value_heads //= tp_size - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - device = torch.device(device) if device is not None else None - for idx in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[idx] - else: - layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - if cache_kwargs is None: - cache_kwargs = {} - - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - return _static_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_kwargs.get("cache_position"), - ) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. 
- """ - kv_length = self.get_max_cache_shape() - return kv_length, 0 + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) -class SlidingWindowCache(StaticCache): +class HybridChunkedCache(Cache): """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) + Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window + attention and global attention in every other layer, with support for prefill chunking (originally implemented + for Llama4). + Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + for global attention. For more information, see the documentation of each subcomponent cache class. - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + See `Cache` for details on common methods that are implemented by all cache classes. 
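How the hybrid caches map `config.layer_types` onto per-layer cache classes can be sketched with plain strings; the type names below follow current hybrid models and are an assumption here, as is spelling out the mapping by hand instead of using `LAYER_CLASS_MAP`:

```python
# Gemma-2-style alternation of sliding and full attention across six decoder layers.
layer_types = ["sliding_attention", "full_attention"] * 3

# Mirrors the `hybrid_map` built in `HybridChunkedCache.__init__`: sliding / chunked layers get a
# chunked sliding layer, everything else a static layer (class names kept as plain strings here).
hybrid_map = {
    "full_attention": "StaticLayer",
    "sliding_attention": "ChunkedSlidingLayer",
    "chunked_attention": "ChunkedSlidingLayer",
}

layer_classes = [hybrid_map[layer_type] for layer_type in layer_types]
print(layer_classes[:4])  # ['ChunkedSlidingLayer', 'StaticLayer', 'ChunkedSlidingLayer', 'StaticLayer']
```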
Example: ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() + HybridCache() ``` """ - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
- ) - max_cache_len = min(config.sliding_window, max_cache_len) - self.sliding_window = config.sliding_window - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - ) + def __init__(self, config: PretrainedConfig, *args, **kwargs): + hybrid_map = LAYER_CLASS_MAP.copy() + hybrid_map["sliding_attention"] = ChunkedSlidingLayer + hybrid_map["chunked_attention"] = ChunkedSlidingLayer + if hasattr(config, "layer_types") and getattr(config, "layer_types", None) is not None: + layer_classes = [hybrid_map[layer_type] for layer_type in config.layer_types] + else: + layer_classes = [StaticLayer] + super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - if cache_position is None: - raise ValueError("`cache_position` must be provided for SlidingWindowCache.") - - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - - return _sliding_cache_update( - self.key_cache[layer_idx], - self.value_cache[layer_idx], - key_states, - value_states, - cache_position, - self.max_cache_len, - ) +class OffloadedHybridCache(HybridChunkedCache): + """ + A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading + cache tensors to CPU when not actively being used. - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling + much longer sequences by offloading inactive layers to CPU memory. - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + See `Cache` for details on common methods that are implemented by all cache classes. + """ - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor - kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - kv_length = max(query_length, self.get_max_cache_shape()) - return kv_length, kv_offset + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) class EncoderDecoderCache(Cache): @@ -1364,6 +1714,8 @@ class EncoderDecoderCache(Cache): Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and cross-attention caches. + See `Cache` for details on common methods that are implemented by all cache classes. 
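The hunks above fold the old hand-rolled hybrid/sliding buffers into layer-composed classes (`HybridChunkedCache`, `OffloadedHybridCache`) built from `ChunkedSlidingLayer`/`StaticLayer`. As a rough sketch of the user-facing constructor (keyword arguments assumed to match the doctest above, not a definitive API reference), a manual prefill could look like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridChunkedCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
inputs = tokenizer("My name is Gemma", return_tensors="pt")

# Reserve room for 10 generated tokens; sliding/chunked layers are capped at the window size internally.
max_cache_len = inputs.input_ids.shape[1] + 10
past_key_values = HybridChunkedCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=max_cache_len,
    device=model.device,
    dtype=model.dtype,
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
# OffloadedHybridCache takes the same arguments and simply adds cache_processor=OffloadedCacheProcessor.
```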
+ Example: ```python @@ -1385,6 +1737,9 @@ class EncoderDecoderCache(Cache): """ + # Override @property from Cache + is_compileable = None + def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): super().__init__() self.self_attention_cache = self_attention_cache @@ -1392,7 +1747,7 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): + for layer_idx in range(len(cross_attention_cache)): self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) def __iter__(self): @@ -1402,10 +1757,10 @@ def __iter__(self): """ for layer_idx in range(len(self)): yield ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1415,10 +1770,10 @@ def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch """ if layer_idx < len(self): return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], + self.self_attention_cache.layers[layer_idx].keys, + self.self_attention_cache.layers[layer_idx].values, + self.cross_attention_cache.layers[layer_idx].keys, + self.cross_attention_cache.layers[layer_idx].values, ) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") @@ -1444,7 +1799,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: @classmethod def from_legacy_cache( - cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None + cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] ) -> "EncoderDecoderCache": """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" cache = cls( @@ -1461,10 +1816,10 @@ def from_legacy_cache( cache.is_updated[layer_idx] = True return cache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) + return self.self_attention_cache.get_seq_length(layer_idx, cache_position) def reset(self): if hasattr(self.self_attention_cache, "reset"): @@ -1498,14 +1853,18 @@ def check_dynamic_cache(self, method: str): # TODO(gante, sanchit-gandhi): move following functionality into `.generate` def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + """ + Crop the past key values up to a new `maximum_length` in terms of tokens. 
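Since the per-layer tensors now live on `cache.layers[idx].keys` / `.values` rather than in `key_cache`/`value_cache` lists, here is a minimal sketch of inspecting an `EncoderDecoderCache` under the new layout (toy tensors, not tied to any model):

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

self_attn, cross_attn = DynamicCache(), DynamicCache()
keys = values = torch.zeros(1, 2, 4, 8)  # [batch_size, num_heads, seq_len, head_dim]
self_attn.update(keys, values, layer_idx=0)
cross_attn.update(keys, values, layer_idx=0)

cache = EncoderDecoderCache(self_attn, cross_attn)
for self_k, self_v, cross_k, cross_v in cache:  # __iter__ still yields one 4-tuple per layer
    print(tuple(self_k.shape), tuple(cross_k.shape))

print(cache.self_attention_cache.layers[0].keys.shape)  # new attribute-based access
```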
`maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search. + """ self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" + """ + Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils` + """ self.check_dynamic_cache(self.batch_split.__name__) self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) @@ -1515,22 +1874,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDec out.append(EncoderDecoderCache(self_attn, cross_attn)) return out - @classmethod - def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - def batch_repeat_interleave(self, repeats: int): """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" self.check_dynamic_cache(self.batch_repeat_interleave.__name__) @@ -1543,469 +1886,367 @@ def batch_select_indices(self, indices: torch.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) - def get_max_cache_shape(self) -> Optional[int]: + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. max capacity) of the cache object""" return self.self_attention_cache.get_max_cache_shape() def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) -class HybridCache(Cache): +def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer (originally implemented for Gemma2). 
- Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention.For more information, see the documentation of each subcomponent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - tp_size (`Optional[int]`, *optional*): - The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache - if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the - number of key/value heads will not be adjusted. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + Parse processor arguments from kwargs based on the processor class init signature. 
- >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + Args: + processor_class: The processor class to inspect, or None + kwargs: Dictionary of keyword arguments - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` + Returns: + tuple: (processor_kwargs, remaining_kwargs) + """ + try: + params = list(inspect.signature(processor_class.__init__).parameters)[2:] + except Exception: + return {}, kwargs + + processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} + remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} + return processor_kwargs, remaining_kwargs + + +def parse_layer_args_from_model_config( + config: Optional[PretrainedConfig], + batch_size: Optional[int] = None, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Optional[torch.dtype] = None, + layer_device_map: Optional[dict[int, torch.device]] = None, + tp_size: Optional[int] = None, + max_batch_size: Optional[int] = None, +) -> dict: """ + Parse layer arguments from model configuration for cache initialization. - is_compileable = True + Args: + config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. + batch_size (`Optional[int]`): Batch size for cache initialization. + max_cache_len (`Optional[int]`): Maximum sequence length for cache. + device (`Union[torch.device, str, None]`): Device for cache tensors. + dtype (`Optional[torch.dtype]`): Data type for cache tensors. + layer_device_map: Per-layer device mapping. + tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. + max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.float32, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - tp_size: Optional[int] = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'hybrid' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." 
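`parse_processor_args` is an internal helper that routes any kwargs matching the processor's `__init__` signature (minus `self` and the cache argument) to the processor and returns everything else untouched. A hedged illustration with a made-up processor class, assuming the helper stays importable from `transformers.cache_utils`:

```python
from transformers.cache_utils import parse_processor_args

class DummyProcessor:  # hypothetical stand-in for a CacheProcessor subclass
    def __init__(self, cache, offload_device="cpu"):
        self.offload_device = offload_device

processor_kwargs, remaining = parse_processor_args(
    DummyProcessor, {"offload_device": "cpu", "max_cache_len": 128}
)
print(processor_kwargs)  # {'offload_device': 'cpu'}
print(remaining)         # {'max_cache_len': 128}
```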
- ) - self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings - # Sliding layers can't be larger than the overall max cache len - self.sliding_window_len = min(config.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + Returns: + `dict`: Dictionary containing parsed layer arguments for cache initialization. + """ + # No model config -> must be a dynamic cache, return bare dict + if config is None: + return {} + # Build the args dict for hybrid, sliding or static + else: + # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) + if ( + getattr(config, "layer_types", None) is not None + and "sliding_attention" in config.layer_types + and "full_attention" in config.layer_types + ): + if getattr(config, "sliding_window", None) is None: + raise ValueError( + "Setting up a hybrid or sliding window KVCache requires the model config supporting " + "sliding window attention, please check if there is a `sliding_window` field in the model " + "config and it's not set to None." + ) + # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) + max_cache_len = max_cache_len or config.max_position_embeddings + if getattr(config, "sliding_window", None) is not None: + sliding_window_len = min(config.sliding_window, max_cache_len) + else: + sliding_window_len = None + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: + head_dim = ( + config.head_dim + if getattr(config, "head_dim", None) is not None + else config.hidden_size // config.num_attention_heads ) - - self._dtype = dtype - self.num_key_value_heads = ( + num_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) if tp_size is not None and tp_size > 1: - if self.num_key_value_heads % tp_size != 0: + if num_heads % tp_size != 0: raise ValueError( - f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." ) # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - self.num_key_value_heads //= tp_size - - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] - else: - self.is_sliding = [False] * config.num_hidden_layers - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim) - self.sliding_window = min(config.sliding_window, max_cache_len) - device = torch.device(device) if device is not None else None - for i in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[i] - else: - layer_device = device - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. 
- cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - if cache_position is None: - raise ValueError("`cache_position` must be provided for HybridCache.") - - is_sliding_layer = self.is_sliding[layer_idx] - - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - - k_cache = self.key_cache[layer_idx] - v_cache = self.value_cache[layer_idx] - key_states = key_states.to(k_cache.dtype) - value_states = value_states.to(v_cache.dtype) - - if is_sliding_layer: - return _sliding_cache_update( - k_cache, - v_cache, - key_states, - value_states, - cache_position, - k_cache.shape[2], # Use actual cache dim as max cache len - ) - else: - return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position) - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + num_heads //= tp_size + layer_args = { + "batch_size": max_batch_size if max_batch_size is not None else batch_size, + "max_cache_len": max_cache_len, + "device": torch.device(device) if device is not None else None, + "dtype": dtype, + "layer_device_map": layer_device_map, + "head_dim": head_dim, + "num_heads": num_heads, + "sliding_window": sliding_window_len, + } + return {k: v for k, v in layer_args.items() if v is not None} + + +LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + "chunked_attention": SlidingWindowLayer, +} +PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { + "offloaded": OffloadedCacheProcessor, + "quanto_quantized": QuantizedCacheProcessor, + "hqq_quantized": HQQQuantizedCacheProcessor, +} + + +### Deprecated classes - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." 
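Taken together, the helper and the two mappings above describe how a fixed-size cache is assembled from a model config: `parse_layer_args_from_model_config` normalizes the shape/device arguments and `LAYER_CLASS_MAP` picks a layer class per entry in `config.layer_types`. A sketch of that wiring (internal names assumed importable from `transformers.cache_utils`; the real call site is the `Cache` constructor):

```python
import torch
from transformers import AutoConfig
from transformers.cache_utils import LAYER_CLASS_MAP, StaticLayer, parse_layer_args_from_model_config

config = AutoConfig.from_pretrained("google/gemma-2-2b")
layer_args = parse_layer_args_from_model_config(
    config, max_batch_size=1, max_cache_len=256, device="cpu", dtype=torch.float32
)
# Non-None keys only: batch_size, max_cache_len, device, dtype, head_dim, num_heads, sliding_window
print(sorted(layer_args))

layer_types = getattr(config, "layer_types", None)
layer_classes = (
    [LAYER_CLASS_MAP[layer_type] for layer_type in layer_types] if layer_types else [StaticLayer]
)
```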
- ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() +class SinkCache(Cache): + """ + Is its now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache. + See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for + general `custom_generate`usage. + """ - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] + # TODO (joao, manuel): Remove this class in v4.59.0 + def __init__(self, **kwargs) -> None: + raise NotImplementedError( + "`SinkCache` has been moved as a `custom_generate` repository on the Hub: " + "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples." + ) - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - local_mask_kv_length = max(query_length, self.sliding_window) - return local_mask_kv_length, local_mask_kv_offset - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset +@dataclass +class CacheConfig: + """ + Base class for cache configs. Deprecated in favor of a simpler dictionary. + """ + cache_implementation: None -class HybridChunkedCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window - attention and global attention in every other layer, with support for chunked attention (originally implemented - for Llama4). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] - for global attention. For more information, see the documentation of each subcomponent cache class. + def __post_init__(self): + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." + ) - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. 
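`SinkCache` itself now only raises and points to the Hub repository; the replacement is loaded through the `custom_generate` mechanism referenced in its docstring. A hedged sketch of the intended call, with the argument names taken from the custom decoding docs linked above:

```python
# Assumes `model` and `inputs` were prepared as in the HybridChunkedCache sketch earlier.
outputs = model.generate(
    **inputs,
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,  # required because the decoding loop is pulled from the Hub repo
    max_new_tokens=20,
)
```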
This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a CacheConfig instance from a dictionary of parameters. + Args: + config_dict (dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. - Example: + Returns: + CacheConfig: Instance of CacheConfig constructed from the dictionary. + """ + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." + ) + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + writer.write(json_string) - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. 
+ """ + return copy.deepcopy(self.__dict__) - is_compileable = True + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.bfloat16, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) - else: - self.sliding_window = config.sliding_window - self.max_cache_len = max_cache_len - # Sliding layers can't be larger than the overall max cache len - self.sliding_window = min(self.sliding_window, self.max_cache_len) - self.max_batch_size = max_batch_size - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self._dtype = dtype - - # If the attribute does not exist in the config, fallback to a simple StaticCache - if hasattr(config, "layer_types"): - self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types] - else: - self.is_sliding = [False] * config.num_hidden_layers + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" - def initialise_cache_layer(self, layer_idx, key_states): - if len(self.key_cache) > layer_idx: - return + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. - num_key_value_heads = key_states.shape[1] - device = key_states.device - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + Args: + kwargs (`dict[str, Any]`): + Dictionary of attributes to tentatively update this class. 
- def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - cumulative_length = self.cumulative_length[layer_idx] - # Update it now that we saved the value above - self.cumulative_length[layer_idx] += key_states.shape[-2] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address - # in memory (the values are the same as the full states, but not the address!!) - if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states - else: - full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) - else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states + Returns: + `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - def update( +@dataclass +class QuantizedCacheConfig(CacheConfig): + """ + Configuration class for quantized cache settings. Deprecated in favor of a simpler dictionary. + + Attributes: + backend (`str`, *optional*, defaults to `"quanto"`): + Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] + nbits (`Optional[int]`, *optional*, defaults to 4): + Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. + axis_key (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. 
+ axis_value (`int`, *optional*, defaults to 0): + Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. + q_group_size (`Optional[int]`, *optional*, defaults to 64): + Size of the quantization group, should be a divisor of the model's hidden dimension. + Defaults to 64. + residual_length (`Optional[int]`, *optional*, defaults to 128): + Length of the residual cache which will always be stored in original precision. + Defaults to 128. + compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. + device (`str`, *optional*, defaults to `"cpu"`): + Device on which to perform computations, should be same as the model's device. + """ + + def __init__( self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - self.initialise_cache_layer(layer_idx, key_states) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], + backend: str = "quanto", + nbits: Optional[int] = 4, + axis_key: Optional[int] = 0, + axis_value: Optional[int] = 0, + q_group_size: Optional[int] = 64, + residual_length: Optional[int] = 128, + compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", + ): + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." ) + self.backend = backend + self.nbits = nbits + self.axis_key = axis_key + self.axis_value = axis_value + self.q_group_size = q_group_size + self.residual_length = residual_length + self.compute_dtype = compute_dtype + self.device = device - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len + def validate(self): + """Validates if the arguments passed are correct""" - def get_seq_length(self, layer_idx: Optional[int] = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: + incorrect_arg_msg = ( + "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + # Check that the values are reasonable in general (nbits, axis) + # Later in QuantizedCache init we check if they are supported for that particular backend + if self.nbits not in [1, 2, 3, 4, 8]: + raise ValueError( + incorrect_arg_msg.format( + key="nbits", + correct_value="2 or 4 or 8", + found_value=self.nbits, + ), + ) + if self.q_group_size <= 0: raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." 
+ incorrect_arg_msg.format( + key="q_group_size", + correct_value="a positive integer", + found_value=self.q_group_size, + ), + ) + if self.residual_length < 0: + raise ValueError( + incorrect_arg_msg.format( + key="residual_length", + correct_value="a positive integer", + found_value=self.residual_length, + ), ) - if len(self.key_cache) == 0: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] - def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. - """ - if self.is_sliding[layer_idx]: - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0) - # This is the true general case for any Cache using local attention (sliding or chunked) - if first_cache_position >= self.sliding_window: - # Here the Cache is already full - local_mask_kv_length = self.sliding_window + query_length - 1 - elif ( - first_cache_position < self.sliding_window - and first_cache_position + query_length > self.sliding_window - ): - # Here the Cache becomes full with the new input - local_mask_kv_length = first_cache_position + query_length - else: - # Here the Cache is still smaller than the local size, but we return the local size as it's static - local_mask_kv_length = self.sliding_window - return local_mask_kv_length, local_mask_kv_offset + if self.axis_key not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_key", + correct_value="`1` or `0`, `-1`", + found_value=self.axis_key, + ), + ) - full_mask_kv_offset = 0 - full_mask_kv_length = self.get_max_cache_shape() - return full_mask_kv_length, full_mask_kv_offset + if self.axis_value not in [0, 1, -1]: + raise ValueError( + incorrect_arg_msg.format( + key="axis_value", + correct_value="`1` or `0` or `-1`", + found_value=self.axis_value, + ), + ) -class OffloadedHybridCache(HybridChunkedCache): - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: torch.dtype = torch.bfloat16, - offload_device: Union[str, torch.device] = torch.device("cpu"), - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - ): - super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) +@dataclass +class StaticCacheConfig(CacheConfig): + """ + Configuration class for static cache settings. + """ - # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps - # track of the original device of each layer - unique_devices = set(layer_device_map.values()) if layer_device_map else set() - if len(unique_devices) > 1: - raise ValueError(f"OffloadedHybridCache does not support multiple devices. 
Got devices: {unique_devices}") + cache_implementation = "static" - self.offload_device = torch.device(offload_device) - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None - # Those will be dynamically created as the other layers (for TP) - self.device_key_cache = None - self.device_value_cache = None - # This gives the index of which on-device full layer to use (we need 2 to avoid race conditions when prefetching) - self.active_device_layer = 0 + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + logger.warning_once( + "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." + ) + self.batch_size = batch_size + self.max_cache_len = max_cache_len + self.device = device def initialise_cache_layer(self, layer_idx, key_states): """Overridden to use the correct device if offloaded layer (and pin memory).""" @@ -2093,289 +2334,13 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: self.device_value_cache[self.active_device_layer].fill_(0.0) -class OffloadedStaticCache(StaticCache): - """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. - - Args: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize - the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`Union[str, torch.device]`): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*): - The default `dtype` to use when initializing the cache. - offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): - The device to offload to. Defaults to CPU. - layer_device_map (`dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - tp_size (`Optional[int]`, *optional*): - The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache - if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the - number of key/value heads will not be adjusted. 
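Both deprecated dataclasses above collapse into plain dictionaries: the keys are simply their former attributes, and `GenerationConfig` now stores that dict directly. A migration sketch, with values mirroring the defaults shown above:

```python
from transformers import GenerationConfig

# Was QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
quantized_cache_config = {"backend": "quanto", "nbits": 4, "q_group_size": 64, "residual_length": 128}

# Was StaticCacheConfig(batch_size=1, max_cache_len=256)
static_cache_config = {"batch_size": 1, "max_cache_len": 256}

generation_config = GenerationConfig(
    cache_implementation="quantized",
    cache_config=quantized_cache_config,
)
```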
- - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: Optional[int], - device: Union[str, torch.device], - dtype: Optional[torch.dtype] = None, - offload_device: Union[str, torch.device] = torch.device("cpu"), - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - tp_size: Optional[int] = None, - ) -> None: - super(Cache, self).__init__() - - # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps - # track of the original device of each layer - unique_devices = set(layer_device_map.values()) if layer_device_map else set() - if len(unique_devices) > 1: - raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") - - self.max_batch_size = max_batch_size - self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) - self.offload_device = torch.device(offload_device) - self._dtype = dtype if dtype is not None else torch.float32 - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - if tp_size is not None and tp_size > 1: - if num_key_value_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - num_key_value_heads //= tp_size - - cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) - - # Create offloaded CPU tensors. - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - for i in range(config.num_hidden_layers): - # First layer is always on-device. - device = self.device if i == 0 else self.offload_device - - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) - - self.key_cache.append(key_cache) - self.value_cache.append(value_cache) - - # Create device tensors. 
- self._device_key_cache: list[torch.Tensor] = [] - self._device_value_cache: list[torch.Tensor] = [] - - for i in range(2): - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) - - self._device_key_cache.append(key_cache) - self._device_value_cache.append(value_cache) - - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the - `cache_position` input to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - key_states = key_states.to(self.key_cache[layer_idx].dtype) - value_states = value_states.to(self.value_cache[layer_idx].dtype) - - if layer_idx == 0: - # Always there. - k_out = self.key_cache[0] - v_out = self.value_cache[0] - else: - # Wait for prefetch stream. - if self._prefetch_stream is not None: - torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) - - k_out = self._device_key_cache[layer_idx & 1] - v_out = self._device_value_cache[layer_idx & 1] - - self._prefetch_layer(layer_idx + 1) - - cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - - # Copy the values to the offloaded device as well. - if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does - # explicitly an in-place operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy the values to the offloaded device as well. - if layer_idx != 0: - cache_position = cache_position.to(self.offload_device) - key_states = key_states.to(self.offload_device) - value_states = value_states.to(self.offload_device) - - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. 
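The hand-written offloaded static implementation removed here is superseded by the processor-based offloading; from the user side the cache is still requested by name through the `offloaded_static` entry kept in `NEED_SETUP_CACHE_CLASSES_MAPPING` (see the configuration hunk further below). A minimal usage sketch, assuming `model` and `inputs` as in the earlier examples:

```python
# The OffloadedStaticCache is built for you when the implementation is selected by name.
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="offloaded_static",
)
```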
- self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - - return self.max_cache_len - - def reset(self) -> None: - """Resets the cache values while preserving the objects.""" - - # Zero out cache. - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address. - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - def _create_key_value_cache_tensors( - self, shape: tuple[int, ...], device: torch.device - ) -> tuple[torch.Tensor, torch.Tensor]: - """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static - addresses for non-CPU tensors. - - Args: - shape (`tuple[int, ...]`): Shape. - device (`torch.device`): Device. - - Returns: - Key and value cache tensors as a tuple. - """ - - is_cpu_device = device == torch.device("cpu") - - key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(key_cache) - torch._dynamo.mark_static_address(value_cache) - - return key_cache, value_cache - - def _prefetch_layer(self, layer_idx: int) -> None: - """Prefetch a layer to the device. Needs to be called in order of layer indices.""" - - # Don't fetch layers that do not exist. - if layer_idx >= len(self.key_cache): - return - - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(layer_idx) - else: - self._prefetch_layer_in_context(layer_idx) - - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - - self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) - self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) - - # TODO (manuel, joao): remove this class, it is here only for backwards compatibility # PEP 562: Lazy loading for deprecated location of MambaCache def __getattr__(name: str) -> Any: if name == "MambaCache": - warnings.warn( - ( - "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " - "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." - ), - FutureWarning, - stacklevel=2, + logger.warning_once( + "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed " + "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead." 
) class MambaCache: diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 165252927c1c..7d2cd21effb2 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -44,7 +44,6 @@ logger = logging.get_logger(__name__) METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") -CACHE_CONFIG_MAPPING = {} NEED_SETUP_CACHE_CLASSES_MAPPING = {} QUANT_BACKEND_CLASSES_MAPPING = {} ALL_CACHE_IMPLEMENTATIONS = [] @@ -56,16 +55,12 @@ HybridChunkedCache, OffloadedHybridCache, OffloadedStaticCache, - QuantizedCacheConfig, QuantoQuantizedCache, SlidingWindowCache, StaticCache, - StaticCacheConfig, ) from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor - CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig - CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache, "offloaded_static": OffloadedStaticCache, @@ -76,9 +71,7 @@ "offloaded_hybrid_chunked": OffloadedHybridCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} - ALL_CACHE_IMPLEMENTATIONS = ( - list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"] - ) + ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"] class GenerationMode(ExplicitEnum): @@ -188,10 +181,8 @@ class GenerationConfig(PushToHubMixin): If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information. - cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): - Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and - it will be converted to its respective `CacheConfig` internally. - Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + cache_config (`dict`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. return_legacy_cache (`bool`, *optional*, default to `True`): Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. @@ -406,10 +397,16 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", True) self.cache_implementation = kwargs.pop("cache_implementation", None) self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING: - cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation] - if isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) + if self.cache_config is not None and not isinstance(self.cache_config, dict): + warnings.warn( + ( + "Passing a CacheConfig object is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary." + ), + FutureWarning, + stacklevel=2, + ) + self.cache_config = self.cache_config.to_dict() + self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None) @@ -611,17 +608,6 @@ def validate(self, strict=False): f"Invalid `cache_implementation` ({self.cache_implementation}). 
Choose one of: " f"{ALL_CACHE_IMPLEMENTATIONS}" ) - if self.cache_config is not None: - cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation) - if cache_class is None: - raise ValueError( - "You provided a `cache_config` but the cache implementation you are using " - f"({self.cache_implementation}) does not require any config. Make sure to use the " - "correct cache implementation matching your cache config." - ) - if not isinstance(self.cache_config, cache_class): - self.cache_config = cache_class.from_dict(self.cache_config) - self.cache_config.validate() # 1.3. Performance attributes if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig): raise ValueError( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e360acdac341..94b243277d29 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -35,7 +35,6 @@ HybridChunkedCache, OffloadedCache, OffloadedHybridCache, - QuantizedCacheConfig, ) from ..configuration_utils import PretrainedConfig from ..dynamic_module_utils import ( @@ -2064,22 +2063,22 @@ def _prepare_cache_for_generation( cache_config = ( generation_config.cache_config if generation_config.cache_config is not None - else QuantizedCacheConfig() + else {"backend": "quanto"} ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config["backend"]] - if cache_config.backend == "quanto" and not is_optimum_quanto_available(): + if cache_config["backend"] == "quanto" and not is_optimum_quanto_available(): raise ImportError( "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. " "Please install it via with `pip install optimum-quanto`" ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): + elif cache_config["backend"] == "HQQ" and not is_hqq_available(): raise ImportError( "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " "Please install it via with `pip install hqq`" ) - model_kwargs[cache_name] = cache_class(cache_config) + model_kwargs[cache_name] = cache_class(**cache_config) elif generation_config.cache_implementation == "offloaded": model_kwargs[cache_name] = OffloadedCache() elif generation_config.cache_implementation == "dynamic": @@ -5215,106 +5214,6 @@ def _ranking_fast( return selected_idx -def _split(data, full_batch_size: int, split_size: int): - """ - Takes care of three cases: - 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim - 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and - return a list of tuples - 3. data is a tuple of tuples, e.g. past_key_values. 
Keep the tuple as it is and split each tuple in it and - return a list of tuples of tuples - (see documentation of ModelOutput) - """ - if data is None: - return [None] * (full_batch_size // split_size) - if isinstance(data, torch.Tensor): - return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] - # New cache format - elif isinstance(data, DynamicCache) or ( - isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) - ): - return data.batch_split(full_batch_size, split_size) - elif isinstance(data, tuple): - # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) - if isinstance(data[0], tuple): - return [ - tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) - for i in range(0, full_batch_size, split_size) - ] - - else: - return [ - tuple(sub_tensor[i : i + split_size] for sub_tensor in data) - for i in range(0, full_batch_size, split_size) - ] - else: - raise TypeError(f"Unexpected attribute type: {type(data)}") - - -def _split_model_inputs( - model_input: Union[ModelOutput, dict], split_size: int, full_batch_size: int, config: PretrainedConfig -) -> list[Union[ModelOutput, dict]]: - """ - Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split - size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from - previous forward pass. - """ - # Edge case: if model_input is None, return a list of Nones - # this happens with Whisper where encoder_outputs is None - if model_input is None: - return [model_input] * (full_batch_size // split_size) - # Infer the class from the object - model_output_cls = type(model_input) - if (full_batch_size % split_size) != 0: - raise ValueError("`full_batch_size` must be divisible by `split_size`") - - if split_size > full_batch_size: - raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") - - # Helper function to split tensors or tuples of tensors - - # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them - keys = ( - model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() - ) - # We only keep keys that are in the model_input - keys = [k for k in keys if k in model_input] - # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a - # ModelOutput object. 
- # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] - - # we split the tensors and tuples of tensors - data_split_list = [ - {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} - for i in range(full_batch_size // split_size) - ] - # bool values are the same and replicated for each split - bool_data = {k: model_input[k] for k in bool_keys} - # encoder_outputs is a ModelOutput object and should be split by its own - if "encoder_outputs" in model_input: - encoder_outputs_split = _split_model_inputs( - model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() - ) - data_split_list = [ - {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) - ] - # logits_to_keep should be replicated for each split, similar to bool values - if "logits_to_keep" in model_input: - data_split_list = [ - {**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list - ] - - # Convert each dictionary in the list to an object of the inferred class - split_model_inputs: list[Union[ModelOutput, dict]] = [ - model_output_cls(**data_split, **bool_data) for data_split in data_split_list - ] - - return split_model_inputs - - def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput: """ Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the @@ -5339,11 +5238,6 @@ def _concat(data): return None if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) - # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 71777d123cda..6fa0e6348d66 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -290,14 +290,14 @@ def __init__(self, model: PreTrainedModel): self.model = model self.static_cache = StaticCache( config=self.model.config, - max_batch_size=self.model.generation_config.cache_config.batch_size, - max_cache_len=self.model.generation_config.cache_config.max_cache_len, - device=self.model.generation_config.cache_config.device, + max_batch_size=self.model.generation_config.cache_config.get("batch_size"), + max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), + device=self.model.generation_config.cache_config.get("device"), dtype=self.model.dtype, ) - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) 
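The buffer registration above, like the many model hunks that follow, reads per-layer tensors through the new `cache.layers[i].keys` / `cache.layers[i].values` attributes rather than the old `key_cache` / `value_cache` lists. A minimal standalone sketch of that access pattern (the checkpoint name and prompt are illustrative only, not taken from this patch):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Any small decoder-only checkpoint works; "gpt2" is just an illustrative choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

cache = DynamicCache()
inputs = tokenizer("The layered cache API", return_tensors="pt")
with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

# Per-layer tensors now live on layer objects instead of two parallel lists.
keys = cache.layers[0].keys      # [batch_size, num_heads, seq_len, head_dim]
values = cache.layers[0].values
print(keys.shape, values.shape)
```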
def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): """ @@ -429,9 +429,9 @@ def __init__( ) # Register all key and value cache tensors as buffers - for i in range(len(self.cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False) + for i in range(len(self.cache)): + self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False) def forward( self, @@ -580,9 +580,9 @@ def __init__(self, model, max_static_cache_length, batch_size): self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) # Register cache buffers to make them exportable - for i in range(len(self.static_cache.key_cache)): - self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) - self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + for i in range(len(self.static_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 3548e706a5e0..617bbb8087e6 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -85,7 +85,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -99,6 +99,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.layers_block_type = config.layers_block_type diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 055b696405dc..02877c3d89cf 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index d567808f95af..ac3636a048a0 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1287,8 +1287,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a873cd6b6967..d17e1664da5b 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -207,8 +207,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 65f0378ef531..445ef06b0bea 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) 
value_states = self.v_proj(current_states) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 9030bd1e5ce6..dc7eac4390c5 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -212,8 +212,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 2bf05cf683ce..89d99b20c33b 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -362,8 +362,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 8c84d936c543..94775242a526 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -181,8 +181,8 @@ def forward( is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False if past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] - value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys + value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values else: key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 33955d3a6b0e..8b099342f6ee 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -62,7 +62,7 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(DynamicCache): +class FalconHybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
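This reparenting, applied here and to several other hybrid caches in the patch (Bamba above; GraniteMoeHybrid, Jamba, Zamba, Zamba2 below), swaps the `DynamicCache` base for `Cache` and shadows the layered attributes with `None` class attributes so these classes can keep managing their own per-layer lists. A rough, hypothetical sketch of that pattern, assuming (as the `Lfm2HybridConvCache` comment later in this patch indicates) that the base `Cache` exposes these names as properties:

```py
import torch
from transformers.cache_utils import Cache

class ToyHybridCache(Cache):  # hypothetical stand-in for the hybrid Mamba-style caches
    # Shadow the base-class properties so plain per-layer lists can be assigned in __init__.
    key_cache = None
    value_cache = None
    is_compileable = False

    def __init__(self, num_hidden_layers: int, batch_size: int, device=None):
        # Mirrors the list-based storage these classes keep managing themselves.
        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(num_hidden_layers)]

    def __len__(self):
        return len(self.key_cache)
```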
@@ -76,6 +76,10 @@ class FalconHybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: FalconH1Config, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 1411cccef9a9..2c4f3e669854 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1329,7 +1329,7 @@ def forward( if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window @@ -1340,9 +1340,9 @@ def forward( ) # Device of past layer may be different from current one - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( - query_states.device + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) + value_states = ( + past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 8e3bcfd1f156..a17bfdba026b 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1770,7 +1770,7 @@ def forward( if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: # Device of past layer may be different from current one - indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + indices = cache_position.to(past_key_value.layers[self.kv_shared_layer_index].keys.device) # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window @@ -1781,9 +1781,9 @@ def forward( ) # Device of past layer may be different from current one - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( - query_states.device + key_states = past_key_value.layers[self.kv_shared_layer_index].keys[:, :, indices].to(query_states.device) + value_states = ( + past_key_value.layers[self.kv_shared_layer_index].values[:, :, indices].to(query_states.device) ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 43822682df35..9f8c08f96f81 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -701,8 +701,9 @@ def 
forward( # Ensure layer_past is on same device as hidden_states (might not be correct) if past_key_values is not None: - past_key_values.key_cache = past_key_values.key_cache.to(hidden_states.device) - past_key_values.value_cache = past_key_values.value_cache.to(hidden_states.device) + for layer in past_key_values.layers: + layer.keys = layer.keys.to(hidden_states.device) + layer.values = layer.values.to(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if causal_mask is not None: diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 8b3f3d1dccd9..40fb784e3ed0 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -221,7 +221,7 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -235,6 +235,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): super().__init__() self.layers_block_type = config.layers_block_type diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index ed4cda28e47d..e60b3a1f1288 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -487,8 +487,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) @@ -604,8 +604,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 0b3ecb59367e..99c3f07bd2eb 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -293,8 +293,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index d6a2aaabd83b..8f259bb0c017 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -183,7 +183,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): +class HybridMambaAttentionDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -197,6 +197,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): super().__init__() self.dtype = dtype diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 0d383769d1c4..ca653e4114af 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -128,6 +128,12 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + # Override @property existing in Cache + max_batch_size = None + is_compileable = False + key_cache = None + value_cache = None + def __init__( self, config: Lfm2Config, @@ -135,7 +141,8 @@ def __init__( dtype: torch.dtype = torch.float32, device: Union[torch.device, str, None] = None, ): - super().__init__() # initialize key and value cache + self.key_cache = [] + self.value_cache = [] self.max_batch_size = max_batch_size self.layer_types = config.layer_types self.first_attention_layer = self.layer_types.index("full_attention") @@ -218,6 +225,35 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def crop(self, max_length: int): + """Crop the cache to the given length""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 338e6ec5242d..75ef05e3182c 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -89,6 +89,12 @@ class Lfm2HybridConvCache(DynamicCache): Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`. """ + # Override @property existing in Cache + max_batch_size = None + is_compileable = False + key_cache = None + value_cache = None + def __init__( self, config: Lfm2Config, @@ -96,7 +102,8 @@ def __init__( dtype: torch.dtype = torch.float32, device: Union[torch.device, str, None] = None, ): - super().__init__() # initialize key and value cache + self.key_cache = [] + self.value_cache = [] self.max_batch_size = max_batch_size self.layer_types = config.layer_types self.first_attention_layer = self.layer_types.index("full_attention") @@ -179,6 +186,35 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + full_mask_kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, full_mask_kv_offset + + def crop(self, max_length: int): + """Crop the cache to the given length""" + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d4e29c619b11..71e5178ef573 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -481,8 +481,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6790872107a1..59641c60846b 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -294,8 +294,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 83de5dd6aebc..da09d6954c2a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -39,10 +39,6 @@ logger = logging.get_logger(__name__) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None if is_mamba_ssm_available(): from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn @@ -334,6 +330,10 @@ def cuda_kernels_forward( # fmt: off def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None): + if is_mambapy_available(): + from mambapy.pscan import pscan + else: + pscan = None batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. 
Gated MLP's linear projection diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5f988b3a82ee..8a3e7965f506 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -229,8 +229,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 0a6880415f9e..5aafbc1de6e8 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -238,8 +238,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 6923fdc91abb..f47bb101decb 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -108,16 +108,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] 
+ self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 423ae27717c4..6844a9d0fc63 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -217,16 +217,14 @@ def batch_repeat_interleave(self, repeats: int): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0) else: - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.layers[layer_idx].batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor): for layer_idx in range(len(self)): if self.linear_cache[layer_idx] != []: self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...] else: - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.layers[layer_idx].batch_select_indices(indices) def crop(self, max_length: int): raise RuntimeError("MiniMaxCache doesnot support `crop` method") diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 7c126f42f1e5..fbc8b287b6ae 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -496,8 +496,8 @@ def forward( ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 9b229e4074c0..c2472e43c151 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -234,8 +234,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 9706d99d7cd5..818ffea807be 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -333,8 +333,8 @@ def forward( # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) diff --git 
a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index a9d0fd9781a7..0d2ed6b402d0 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -379,8 +379,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 05e34da4d2b3..d37ee0007c5a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -228,8 +228,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index ee48a02b04cf..e2e04804c038 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -249,8 +249,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 0662acf7e698..4257132452b1 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -352,10 +352,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 5d080e8f0c99..820d59bad3e9 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -773,8 +773,8 @@ def 
forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.key(current_states) value_states = self.value(current_states) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 4236476349b5..88304f49647f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -424,8 +424,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 795dfb587421..bdf712a2f949 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -323,8 +323,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d71dffb98782..f22484047bdc 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -516,8 +516,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b5ff699f69a3..e1570a731d33 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -504,8 +504,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = 
curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index c2fdbf5fc7d4..eba0f5790311 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -339,8 +339,8 @@ def forward( key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) past_key_value.is_updated[self.layer_idx] = True else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 5c72d76b4e17..72d19151ce2a 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -291,8 +291,8 @@ def forward( key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) past_key_value.is_updated[self.layer_idx] = True else: - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index d03430002fce..84d5de004c51 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -394,8 +394,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 70a174474b7d..817ea913aed5 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -602,8 +602,8 @@ def forward( current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = 
curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 47b11acfd8dc..cf2e999449bf 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -288,8 +288,8 @@ def forward( current_states = encoder_hidden_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values else: key_states = self.k(current_states) value_states = self.v(current_states) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index f72ce7bd40ba..1cdca9ba0382 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1149,8 +1149,8 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None for layer_idx in range(self.config.decoder_layers): layer_past_key_values = [] for cache_cls in [values.self_attention_cache, values.cross_attention_cache]: - for v in [cache_cls.key_cache, cache_cls.value_cache]: - layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu()) + for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]: + layer_past_key_values.append(v[batch_idx][None].cpu()) all_past_key_values.append(tuple(layer_past_key_values)) return EncoderDecoderCache.from_legacy_cache(tuple(all_past_key_values)) else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7b74c4c0b853..917113cbb00c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -330,8 +330,8 @@ def forward( current_states = key_value_states if key_value_states is not None else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = past_key_value.layers[self.layer_idx].keys + value_states = past_key_value.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index b1a8f8c51696..16290ea4e1b7 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -93,7 +93,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache(DynamicCache): +class ZambaHybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
@@ -107,8 +107,13 @@ class ZambaHybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ + key_cache = None + value_cache = None + is_compileable = False + def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.dtype = dtype + self.is_compileable = False self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba self.intermediate_size = config.mamba_expand * config.hidden_size @@ -138,6 +143,12 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.key_cache) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index fe1482bcdfcb..77838cc63ac5 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -97,7 +97,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache(DynamicCache): +class Zamba2HybridDynamicCache(Cache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -111,6 +111,10 @@ class Zamba2HybridDynamicCache(DynamicCache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" + key_cache = None + value_cache = None + is_compileable = False + def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): @@ -143,6 +147,12 @@ def __init__( self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + def __len__(self): + return len(self.key_cache) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + def update( self, key_states: torch.Tensor, @@ -1364,7 +1374,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 05565c60d6f4..55e3eb45e860 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1133,7 +1133,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if past_key_values and not past_key_values.has_previous_state: + if past_key_values is not None and not past_key_values.has_previous_state: past_key_values.has_previous_state = True output = BaseModelOutputWithPast( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index fab1672b5c86..26a5a7bf3d1e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1577,9 +1577,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. 
Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1587,30 +1585,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1618,10 +1616,18 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + if is_legacy_cache: + self_attention_layer_keys = past_kv[i][0] + self_attention_layer_values = past_kv[i][1] + elif getattr(past_kv, "layers", None) is None: + # Cache is not layered (i.e., Mamba derivatives) + self_attention_layer_keys = past_kv.key_cache[i] + self_attention_layer_values = past_kv.value_cache[i] + else: + self_attention_layer_keys = past_kv.layers[i].keys + self_attention_layer_values = past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @pytest.mark.generate def test_generate_from_random_inputs_embeds(self): @@ -1804,8 +1810,8 @@ def test_generate_from_inputs_embeds_with_static_cache(self): max_length = max_new_tokens + inputs_embeds.shape[1] - 1 cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] self.assertIsInstance(outputs.past_key_values, StaticCache) - self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) - self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) + self.assertEqual(len(outputs.past_key_values), num_hidden_layers) + self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape) @pytest.mark.generate def test_generate_continue_from_past_key_values(self): @@ -2027,8 +2033,8 @@ def test_generate_with_static_cache(self): num_hidden_layers = text_config.num_hidden_layers cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) - self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) + self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers) + self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape) # Check 2: The outputs must be similar to the case with dynamic cache dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) @@ -2629,12 +2635,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for 
layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) # Legacy cache format checks. This branch should be removed when all models use `Cache` by default @@ -4040,13 +4046,13 @@ def test_generate_with_static_cache_multi_accelerator(self): self.assertTrue(isinstance(results.past_key_values, StaticCache)) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @pytest.mark.generate @require_torch_multi_accelerator @@ -4118,13 +4124,13 @@ def test_init_static_cache_multi_accelerator(self): results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) + keys_0 = results.past_key_values.layers[0].keys + values_0 = results.past_key_values.layers[0].values + self.assertTrue(keys_0.device == values_0.device == torch.device(0)) - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) + keys_1 = results.past_key_values.layers[1].keys + values_1 = results.past_key_values.layers[1].values + self.assertTrue(keys_1.device == values_1.device == torch.device(1)) @slow def test_padding_input_contrastive_search_gpt2(self): diff --git a/tests/models/deepseek_v2/test_modeling_deepseek_v2.py b/tests/models/deepseek_v2/test_modeling_deepseek_v2.py index 02d087cb8b9a..0bdc6884590f 100644 --- a/tests/models/deepseek_v2/test_modeling_deepseek_v2.py +++ b/tests/models/deepseek_v2/test_modeling_deepseek_v2.py @@ -168,14 +168,9 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value expected_value_shape = expected_common_shape + (config.v_head_dim,) if isinstance(decoder_past_key_values, Cache): - self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_key_shape] * len(decoder_past_key_values.key_cache), - ) - self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_value_shape] * len(decoder_past_key_values.value_cache), - ) + for layer in decoder_past_key_values.layers: + self.assertEqual(layer.keys.shape, expected_key_shape) + self.assertEqual(layer.values.shape, expected_value_shape) @unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape") def test_generate_compilation_all_outputs(self): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 6c0c3a19d067..87f7b2abb0e9 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ 
b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -440,13 +440,11 @@ def test_past_key_values_format(self): # difference: last dim k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim v_embed_dim = config.v_head_dim - self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) - self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) # build the full cache shapes num_hidden_layers = config.num_hidden_layers - all_cache_shapes = [ - [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) - ] + all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)] super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_large_accelerator diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index 447491f90102..c6d1547afb40 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -399,12 +399,12 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), + [layer.keys.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) self.assertListEqual( - [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), + [layer.values.shape for layer in decoder_past_key_values.layers], + [expected_shape] * len(decoder_past_key_values.layers), ) def _check_scores(self, batch_size, scores, generated_length, config): diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 31c9b72e28b5..5c538ac4b410 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -38,7 +38,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model + from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model from transformers.models.falcon_h1.modeling_falcon_h1 import ( FalconHybridMambaAttentionDynamicCache, ) @@ -272,6 +272,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {} ) + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + if isinstance(decoder_past_key_values, Cache): + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + 
[value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + + # Legacy cache format checks. This branch should be removed when all models use `Cache` by default + else: + self.assertListEqual( + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], + [True] * len(decoder_past_key_values), + ) + # check shape key, value + self.assertListEqual( + [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), + ) + self.assertListEqual( + [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), + ) + def setUp(self): self.model_tester = FalconH1ModelTester(self) self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index b0a0a6a3ccb4..ecd2af9fdc6c 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -235,8 +235,8 @@ def copy_cache(cache: DynamicCache): """Deep copy a DynamicCache to reuse the same one multiple times.""" new_cache = cache for i in range(len(cache)): - new_cache.key_cache[i] = cache.key_cache[i].clone() - new_cache.value_cache[i] = cache.value_cache[i].clone() + new_cache.layers[i].keys = cache.layers[i].keys.clone() + new_cache.layers[i].values = cache.layers[i].values.clone() # Cached forward once with the attention mask provided and the other time without it (which should assume full attention) # We need to run both on a copy of the cache, otherwise it is modified in-place diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index 0020c5c78edc..100787ece9ea 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -271,7 +271,7 @@ def create_and_check_model( self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertIsNotNone(decoder_past) self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers) - self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers) + self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers) def check_prepare_lm_labels_via_shift_left( self, @@ -1060,9 +1060,7 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3. Check cache shapes # 3.1. 
Encoder-Decoder checks if config.is_encoder_decoder: - num_cache_decoder_layers = ( - len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) - ) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1070,30 +1068,30 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + self_attention_layer_keys = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys ) - self_attention_layer_value_cache = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + self_attention_layer_values = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values ) - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_key_cache = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + cross_attention_layer_keys = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys ) - cross_attention_layer_value_cache = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + cross_attention_layer_values = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values ) - cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] - cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] - self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] + cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] + self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) # 3.2. 
Decoder-only checks else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) self.assertEqual(num_cache_decoder_layers, num_decoder_layers) for i in range(num_decoder_layers): @@ -1101,10 +1099,10 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple # Self attention - self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] - self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] - self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys + self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values + self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index b1998b7cfede..8dbcc1194314 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -36,7 +36,7 @@ slow, torch_device, ) -from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal +from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal if is_torch_available(): @@ -49,8 +49,12 @@ DynamicCache, Gemma2Config, GenerationConfig, + HQQQuantizedCacheProcessor, HybridCache, + HybridChunkedCache, LlamaConfig, + QuantizedCache, + QuantoQuantizedCacheProcessor, SlidingWindowCache, StaticCache, convert_and_export_with_cache, @@ -252,6 +256,59 @@ def test_cache_beam_search(self, cache_implementation): decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, EXPECTED_GENERATION) + @parameterized.expand([("quanto"), ("HQQ")]) + def test_quantized_cache_generation(self, backend): + """Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends.""" + if backend == "quanto": + if not is_optimum_quanto_available(): + self.skipTest("Quanto is not available") + axis_key, axis_value = 0, 0 + # This output is taken from a run with the same parameters, and is known to be correct + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + elif backend == "HQQ": + if not is_hqq_available(): + self.skipTest("HQQ is not available") + axis_key, axis_value = 1, 1 + # HQQ has slightly different numerics + expected_generation = ["The cat's whiskers are also a sign of anxiety."] + else: + return + + inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device) + + gen_out = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=10, + return_dict_in_generate=True, + cache_implementation="quantized", + cache_config={ + "backend": backend, + "nbits": 4, + "q_group_size": 16, + "residual_length": 4, + "axis_key": axis_key, + "axis_value": axis_value, + }, + disable_compile=True, + ) + + self.assertIsInstance(gen_out.past_key_values, QuantizedCache) + processor = gen_out.past_key_values.cache_processor + if backend == "quanto": + 
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) + elif backend == "hqq": + self.assertIsInstance(processor, HQQQuantizedCacheProcessor) + + decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) + self.assertListEqual(decoded, expected_generation) + + self.assertTrue(len(processor._quantized_keys) > 0) + + # Check that something is actually quantized + has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys) + self.assertTrue(has_been_quantized) + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" @@ -566,7 +623,7 @@ def test_dynamic_cache_exportability(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -587,11 +644,9 @@ def test_dynamic_cache_exportability(self): use_cache=True, ) self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5)) - for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2, atol=1e-5)) - - for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2, atol=1e-5)) + for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5)) + self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5)) def test_dynamic_cache_exportability_multiple_run(self): # When exporting with DynamicCache, you should export two graphs: @@ -615,7 +670,7 @@ def test_dynamic_cache_exportability_multiple_run(self): past_key_values=DynamicCache(), use_cache=True, ) - self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers) + self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers) self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs)) self.assertEqual( 3, @@ -640,9 +695,9 @@ def test_dynamic_cache_exportability_multiple_run(self): shapes = torch.export.ShapesCollection() dyn = torch.export.Dim("seq", max=512) - for ix in range(len(past_key_values.key_cache)): - shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None) - shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None) + for ix in range(len(past_key_values)): + shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None) + shapes[past_key_values.layers[ix].values] = (None, None, dyn, None) ep_second = torch.export.export( model, @@ -683,11 +738,9 @@ def test_dynamic_cache_exportability_multiple_run(self): use_cache=True, ) - for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): - self.assertTrue(torch.allclose(k1, k2, atol=1e-5)) - - for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): - self.assertTrue(torch.allclose(v1, v2, atol=1e-5)) + for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers): + self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5)) + self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5)) @unittest.skip("Runs 
on my machine locally, passed, no idea why it does not online") def test_static_cache_exportability(self): @@ -726,8 +779,8 @@ def test_static_cache_exportability(self): self.assertEqual(model.generation_config.cache_implementation, cache_implementation) self.assertEqual(model.generation_config.max_length, max_cache_len) self.assertTrue(model.generation_config.cache_config is not None) - self.assertEqual(model.generation_config.cache_config.batch_size, batch_size) - self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len) + self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size) + self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len) exported_program = convert_and_export_with_cache(model) @@ -830,7 +883,7 @@ def setUp(self): head_dim=1, hidden_size=1, sliding_window=self.window_size, - sliding_window_pattern=2, # Default pattern for hybrid sliding + layer_types=["full_attention"] * 1, # Static cache by default ) def test_static_cache_out_of_bounds(self): @@ -867,7 +920,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed" ) # Scenario 2: Fill to capacity @@ -878,7 +931,7 @@ def test_static_cache(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" + static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed" ) def test_sliding_window_cache(self): @@ -897,7 +950,9 @@ def test_sliding_window_cache(self): result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ # Scenario 1: Update within window, no slide yet - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + config = copy.deepcopy(self.config) + config.layer_types = ["sliding_attention"] * config.num_hidden_layers + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -912,13 +967,13 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "SlidingWindowCache Scenario 1 failed", ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, @@ -933,13 +988,13 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "SlidingWindowCache Scenario 2 failed", ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, 
max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, @@ -948,13 +1003,13 @@ def test_sliding_window_cache(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - sliding_cache.key_cache[0][0, 0, :, 0].tolist(), + sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "SlidingWindowCache Scenario 3 failed", ) def test_hybrid_cache_static_mode(self): - """Test HybridCache in static mode with hardcoded assertions. + """Test HybridCache with only 1 static layer. Scenario 1: Static layer behavior prefill: [1.0, 2.0, 0.0, 0.0] @@ -964,7 +1019,7 @@ def test_hybrid_cache_static_mode(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ config = copy.deepcopy(self.config) - config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0) + config.layer_types = ["full_attention"] * config.num_hidden_layers # Scenario 1 hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) @@ -982,7 +1037,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Static Scenario 1 failed", ) @@ -995,7 +1050,7 @@ def test_hybrid_cache_static_mode(self): cache_kwargs={"cache_position": torch.tensor([3])}, ) self.assertEqual( - hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "HybridCache Static Scenario 2 failed", ) @@ -1018,8 +1073,10 @@ def test_hybrid_cache_sliding_mode(self): input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens) """ + config = copy.deepcopy(self.config) + config.layer_types = ["sliding_attention"] * config.num_hidden_layers # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1034,13 +1091,13 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "HybridCache Sliding Scenario 1 failed", ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, @@ -1055,7 +1112,7 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [2.0, 3.0, 4.0, 5.0], "HybridCache Sliding Scenario 2 failed", ) @@ -1068,13 +1125,13 @@ def 
test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 3 failed", ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] hybrid_cache.update( key_states=long_prefill, @@ -1083,7 +1140,278 @@ def test_hybrid_cache_sliding_mode(self): cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, ) self.assertEqual( - hybrid_cache.key_cache[0][0, 0, :, 0].tolist(), + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), [3.0, 4.0, 5.0, 6.0], "HybridCache Sliding Scenario 4 failed", ) + + def test_dynamic_cache(self): + """Test DynamicCache with manually prefilled states and hardcoded assertions. + Scenario 1: prefill and update for one layer + prefill: [1.0, 2.0] + update pos 2: [1.0, 2.0, 3.0] + Scenario 2: prefill and update for two layers independently + """ + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] + update3 = torch.tensor(3.0)[None, None, None, None] + update4 = torch.tensor(4.0)[None, None, None, None] + + # Scenario 1: prefill and update for one layer + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(update3, update3, 0) + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed") + cache.update(update4, update4, 0) + self.assertEqual( + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed" + ) + + # Scenario 2: prefill and update for two layers independently + prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None] + update3_1 = torch.tensor(30.0)[None, None, None, None] + update4_1 = torch.tensor(40.0)[None, None, None, None] + + cache = DynamicCache() + cache.update(prefill, prefill, 0) + cache.update(prefill1, prefill1, 1) + + cache.update(update3, update3, 0) + cache.update(update3_1, update3_1, 1) + cache.update(update4, update4, 0) + cache.update(update4_1, update4_1, 1) + self.assertEqual( + cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed" + ) + self.assertEqual( + cache.layers[1].keys[0, 0, :, 0].tolist(), + [10.0, 20.0, 30.0, 40.0], + "DynamicCache Scenario 2 layer 1 failed", + ) + + def test_hybrid_cache(self): + """ + Test HybridCache with a mix of static and sliding layers, + with prefill size bigger than sliding window. 
+ + prefill: + static: [1.0, 2.0, 3.0] + sliding: [10.0, 20.0, 30.0] + (stores only [20.0, 30.0]) + + update pos 4: + static: [1.0, 2.0, 3.0, 5.0] + sliding: [30.0, 50.0] + """ + config = copy.deepcopy(self.config) + config.num_hidden_layers = 2 + config.layer_types = ["full_attention", "sliding_attention"] + config.sliding_window = 2 + hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + + # Prefill both layers up to cache capacity + prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] + # Sliding window is 2, so it should return full [10.0, 20.0, 30.0], but store only [20.0, 30.0] + prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None] + + # Update static layer (layer 0) + res_static = hybrid_cache.update( + key_states=prefill_static, + value_states=prefill_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + + # Update sliding layer (layer 1) + res_sliding = hybrid_cache.update( + key_states=prefill_sliding, + value_states=prefill_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.arange(3), "sliding_window": self.window_size}, + ) + + # Verify initial states + self.assertEqual( + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "Initial static layer state is wrong", + ) + self.assertEqual( + res_static[0][0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 0.0], + "Static layer did not return the correct value.", + ) + self.assertEqual( + hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(), + [20.0, 30.0], + "Initial sliding layer state is wrong", + ) + self.assertEqual( + res_sliding[0][0, 0, :, 0].tolist(), + [10.0, 20.0, 30.0], + "Sliding layer did not return the correct value.", + ) + + # Update at position 4 + new_key_static = torch.tensor(5.0)[None, None, None, None] + new_key_sliding = torch.tensor(50.0)[None, None, None, None] + + # Update static layer (layer 0) + hybrid_cache.update( + key_states=new_key_static, + value_states=new_key_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + # Update sliding layer (layer 1) + hybrid_cache.update( + key_states=new_key_sliding, + value_states=new_key_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + # The static layer does not slide, so it should have updated the element at position 3 + self.assertEqual( + hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), + [1.0, 2.0, 3.0, 5.0], + "Static layer did not update as expected.", + ) + + # The sliding layer should have shifted, discarding the first element and adding the new one at the end + self.assertEqual( + hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(), + [30.0, 50.0], + "Sliding layer did not slide as expected.", + ) + + def test_hybrid_chunked_cache(self): + """ + Test HybridChunkedCache with both static and sliding layers and special cases: + 1. a pre-fill longer than the sliding window + 2. a single-token decoding step (normal generation) + 3. a multi-token decoding step after the window is already full + + Sliding-window size: 2 + Static layer is full-attention. 
+ ───────────────────────────────────────────── + Prefill: + static : [1, 2, 3] + sliding : [10, 20, 30] (cache keeps [20, 30]) + +1 token: + static : [1, 2, 3, 5] + sliding : [30, 50] (returned [30, 50]) + +2 tokens: + sliding : [60, 70] (returned [50, 60, 70]) + """ + + config = copy.deepcopy(self.config) + config.num_hidden_layers = 2 + config.layer_types = ["full_attention", "sliding_attention"] + config.sliding_window = 2 + max_cache_len = 4 + chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) + + # 1) PREFILL (3 tokens > sliding_window) + prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] + prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None] + + res_static = chunked_cache.update( + key_states=prefill_static, + value_states=prefill_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + res_sliding = chunked_cache.update( + key_states=prefill_sliding, + value_states=prefill_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.arange(3)}, + ) + + # Static layer keeps everything + self.assertEqual(res_static[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0]) + # Sliding layer returned full prompt but stored the tail + self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0]) + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0]) + + # 2) ONE-TOKEN UPDATE (normal decode) + new_static = torch.tensor(5.0)[None, None, None, None] + new_sliding = torch.tensor(50.0)[None, None, None, None] + + chunked_cache.update( + key_states=new_static, + value_states=new_static, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + res_one = chunked_cache.update( + key_states=new_sliding, + value_states=new_sliding, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([3])}, + ) + + self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0]) + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0]) + self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0]) + + # 3) TWO-TOKEN UPDATE after window is full + new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None] + res_two = chunked_cache.update( + key_states=new_sliding_2, + value_states=new_sliding_2, + layer_idx=1, + cache_kwargs={"cache_position": torch.tensor([4, 5])}, # arbitrary positions; ignored in full mode + ) + + # Cache now keeps the latest two tokens + self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0]) + # Returned tensor contains previous last token + new ones + self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0]) + + def test_hybrid_chunked_cache_extra_cases(self): + """ + Covers the new cases that appear on prefill chunking: + 1) Not full multi-token update (cache_position[0] + update_len <= max_cache_len) + 2) Multi-token update crossing the window (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len) + + Single sliding layer, max_cache_len = 3. 
+ + Step 0 (prefill 2 tokens, update_len < max_cache_len) + cache = [10, 20, 0] returned [10, 20, 0] + + Step 1 (add 2 tokens, p = 2, update_len = 2, p + update_len = 4 > max_cache_len) + cache = [20, 30, 40] returned [10, 20, 30, 40] + """ + + config = copy.deepcopy(self.config) + config.num_hidden_layers = 1 + config.layer_types = ["sliding_attention"] + config.sliding_window = 3 + cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3) + + # Step 0 : multi-token prefill + first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2 + returned_0 = cache.update( + key_states=first_chunk, + value_states=first_chunk, + layer_idx=0, + cache_kwargs={"cache_position": torch.arange(2)}, # p = 0,1 + ) + + # internal cache should have first two tokens and a zero pad + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 0.0]) + self.assertEqual(returned_0[0][0, 0, :, 0].tolist(), [10.0, 20.0, 0.0]) + + # Step 1 : multi-token update crossing the window boundary + second_chunk = torch.tensor([30.0, 40.0])[None, None, :, None] # L = 2 + returned_1 = cache.update( + key_states=second_chunk, + value_states=second_chunk, + layer_idx=0, + cache_kwargs={"cache_position": torch.tensor([2, 3])}, # p = 2 + ) + + self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [20.0, 30.0, 40.0]) + self.assertEqual(returned_1[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0])
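
A minimal sketch of the layered access pattern these changes converge on, assuming the refactored `DynamicCache` API where per-layer tensors live in `cache.layers[idx].keys` / `cache.layers[idx].values` (mirroring the test updates above; illustrative only, not part of the diff):

```py
import torch
from transformers import DynamicCache

# Toy key/value states with shape [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 2, 4, 8)
value_states = torch.randn(1, 2, 4, 8)

cache = DynamicCache()
keys, values = cache.update(key_states, value_states, layer_idx=0)

# Per-layer tensors are reached through `cache.layers[idx]`
# instead of the old `cache.key_cache[idx]` / `cache.value_cache[idx]` lists.
assert len(cache) == 1
assert cache.layers[0].keys.shape == (1, 2, 4, 8)
assert torch.equal(keys, cache.layers[0].keys)
assert torch.equal(values, cache.layers[0].values)
```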