diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index c64ba2a3ca43..b19e724e06d0 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -363,37 +363,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens - get_max_cache_shape - reset - reorder_cache + - lazy_initialization [[autodoc]] DynamicLayer - update + - lazy_initialization - crop - batch_repeat_interleave - batch_select_indices [[autodoc]] StaticLayer - update + - lazy_initialization [[autodoc]] SlidingWindowLayer - update + - lazy_initialization -[[autodoc]] CacheProcessor - - pre_update - - post_update - -[[autodoc]] OffloadedCacheProcessor - - pre_update - -[[autodoc]] QuantizedCacheProcessor - - post_update - -[[autodoc]] QuantoQuantizedCacheProcessor - - post_update +[[autodoc]] QuantoQuantizedLayer + - update + - lazy_initialization -[[autodoc]] HQQQuantizedCacheProcessor - - post_update +[[autodoc]] HQQQuantizedLayer + - update + - lazy_initialization [[autodoc]] Cache - update + - early_initialization - get_seq_length - get_mask_sizes - get_max_cache_shape @@ -411,12 +408,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] QuantoQuantizedCache -[[autodoc]] QuantoQuantizedCacheProcessor - [[autodoc]] HQQQuantizedCache -[[autodoc]] HQQQuantizedCacheProcessor - [[autodoc]] OffloadedCache [[autodoc]] StaticCache diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md index a1b6dd81ff16..256bba7c7625 100644 --- a/docs/source/en/kv_cache.md +++ b/docs/source/en/kv_cache.md @@ -312,7 +312,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) # Init StaticCache with big enough max-length (1024 tokens for the below example) # You can also init a DynamicCache, if that suits you better -prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device.type, dtype=torch.bfloat16) +prompt_cache = StaticCache(config=model.config, max_cache_len=1024) INITIAL_PROMPT = "You are a helpful assistant. " inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 0295a5bf1b34..9b6bdb8b614f 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -93,11 +93,8 @@ model.generation_config.max_new_tokens = 16 past_key_values = StaticCache( config=model.config, - max_batch_size=1, # If you plan to reuse the cache, make sure the cache length is large enough for all cases max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2), - device=model.device, - dtype=model.dtype ) outputs = model.generate(**input_ids, past_key_values=past_key_values) print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) @@ -159,7 +156,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): past_key_values = StaticCache( - config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + config=model.config, max_cache_len=4096 ) cache_position = torch.arange(seq_length, device=torch_device) generated_ids = torch.zeros( diff --git a/docs/source/en/model_doc/gemma2.md b/docs/source/en/model_doc/gemma2.md index 84f11b1eb24f..08ff2359f4c1 100644 --- a/docs/source/en/model_doc/gemma2.md +++ b/docs/source/en/model_doc/gemma2.md @@ -138,8 +138,7 @@ visualizer("You are an assistant. 
Make sure you print me") inputs = tokenizer(text="My name is Gemma", return_tensors="pt") max_generated_length = inputs.input_ids.shape[1] + 10 - past_key_values = HybridCache(config=model.config, max_batch_size=1, - max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length) outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) ``` diff --git a/docs/source/ko/internal/generation_utils.md b/docs/source/ko/internal/generation_utils.md index 9ef510fc2088..cedc34bd74f7 100644 --- a/docs/source/ko/internal/generation_utils.md +++ b/docs/source/ko/internal/generation_utils.md @@ -362,21 +362,11 @@ generation_output[:2] [[autodoc]] SlidingWindowLayer - update -[[autodoc]] CacheProcessor - - pre_update - - post_update - -[[autodoc]] OffloadedCacheProcessor - - pre_update - -[[autodoc]] QuantizedCacheProcessor - - post_update - -[[autodoc]] QuantoQuantizedCacheProcessor - - post_update +[[autodoc]] QuantoQuantizedLayer + - update -[[autodoc]] HQQQuantizedCacheProcessor - - post_update +[[autodoc]] HQQQuantizedLayer + - update [[autodoc]] Cache - update @@ -397,12 +387,8 @@ generation_output[:2] [[autodoc]] QuantoQuantizedCache -[[autodoc]] QuantoQuantizedCacheProcessor - [[autodoc]] HQQQuantizedCache -[[autodoc]] HQQQuantizedCacheProcessor - [[autodoc]] OffloadedCache [[autodoc]] StaticCache diff --git a/docs/source/ko/llm_optims.md b/docs/source/ko/llm_optims.md index f6eaa58c0004..2a631721b88d 100644 --- a/docs/source/ko/llm_optims.md +++ b/docs/source/ko/llm_optims.md @@ -99,11 +99,8 @@ model.generation_config.max_new_tokens = 16 past_key_values = StaticCache( config=model.config, - max_batch_size=1, # 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다 max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2), - device=model.device, - dtype=model.dtype ) outputs = model.generate(**input_ids, past_key_values=past_key_values) print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) @@ -161,7 +158,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): past_key_values = StaticCache( - config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + config=model.config, max_cache_len=4096 ) cache_position = torch.arange(seq_length, device=torch_device) generated_ids = torch.zeros( diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e7167c2d2900..f99eca0e0bbf 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -377,23 +377,18 @@ "StaticLayer", "SlidingWindowLayer", "ChunkedSlidingLayer", - "CacheProcessor", - "OffloadedCacheProcessor", - "QuantizedCacheProcessor", - "QuantoQuantizedCacheProcessor", - "HQQQuantizedCacheProcessor", + "QuantoQuantizedLayer", + "HQQQuantizedLayer", "Cache", "CacheConfig", "DynamicCache", "EncoderDecoderCache", "HQQQuantizedCache", - "HQQQuantizedCacheProcessor", "HybridCache", "HybridChunkedCache", "OffloadedCache", "OffloadedStaticCache", "QuantizedCache", - "QuantoQuantizedCacheProcessor", "QuantizedCacheConfig", "QuantoQuantizedCache", "SinkCache", @@ -586,9 +581,12 @@ # All modeling imports from .cache_utils import Cache as Cache from .cache_utils import CacheConfig as CacheConfig + from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer from .cache_utils import DynamicCache as DynamicCache + from .cache_utils import 
DynamicLayer as DynamicLayer from .cache_utils import EncoderDecoderCache as EncoderDecoderCache from .cache_utils import HQQQuantizedCache as HQQQuantizedCache + from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer from .cache_utils import HybridCache as HybridCache from .cache_utils import MambaCache as MambaCache from .cache_utils import OffloadedCache as OffloadedCache @@ -596,9 +594,12 @@ from .cache_utils import QuantizedCache as QuantizedCache from .cache_utils import QuantizedCacheConfig as QuantizedCacheConfig from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache + from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer from .cache_utils import SinkCache as SinkCache from .cache_utils import SlidingWindowCache as SlidingWindowCache + from .cache_utils import SlidingWindowLayer as SlidingWindowLayer from .cache_utils import StaticCache as StaticCache + from .cache_utils import StaticLayer as StaticLayer from .configuration_utils import PretrainedConfig as PretrainedConfig from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bb5aac99b33b..3fcfecbf911e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,26 +1,33 @@ import copy -import functools -import importlib.metadata -import inspect import json import os from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union import torch -from packaging import version from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 from .configuration_utils import PretrainedConfig -from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging +from .utils import ( + is_hqq_available, + is_quanto_greater, + is_torch_greater_or_equal, + is_torchdynamo_compiling, + logging, +) +if _is_quanto_greater_than_0_2_5 := is_quanto_greater("0.2.5", accept_dev=True): + from optimum.quanto import MaxOptimizer, qint2, qint4, quantize_weight + if is_hqq_available(): from hqq.core.quantize import Quantizer as HQQQuantizer +_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True) + logger = logging.get_logger(__name__) @@ -35,12 +42,12 @@ def __init__(self): @abstractmethod def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, + self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None ) -> tuple[torch.Tensor, torch.Tensor]: ... + @abstractmethod + def lazy_initialization(self, key_states: torch.Tensor): ... + @abstractmethod def get_seq_length(self, cache_position=None) -> int: ... @@ -50,10 +57,23 @@ def get_max_cache_shape(self) -> int: ... @abstractmethod def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ... 
+ def offload(self): + """Offload this layer's data to CPU device.""" + if self.keys is not None: + self.keys = self.keys.to("cpu", non_blocking=True) + self.values = self.values.to("cpu", non_blocking=True) + + def prefetch(self): + """In case of layer offloading, this allows to move the data back to the layer's device ahead of time.""" + if self.keys is not None and self.keys.device != self.device: + self.keys = self.keys.to(self.device, non_blocking=True) + self.values = self.values.to(self.device, non_blocking=True) + def reset(self) -> None: """Resets the cache values while preserving the objects""" - self.keys.zero_() - self.values.zero_() + if self.keys is not None: + self.keys.zero_() + self.values.zero_() def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: """Reorders this layer's cache for beam search.""" @@ -75,6 +95,11 @@ class DynamicLayer(CacheLayerMixin): is_sliding = False + def lazy_initialization(self, key_states: torch.Tensor): + self.dtype, self.device = key_states.dtype, key_states.device + self.keys = torch.tensor([], dtype=self.dtype, device=self.device) + self.values = torch.tensor([], dtype=self.dtype, device=self.device) + def update( self, key_states: torch.Tensor, @@ -95,12 +120,12 @@ def update( Return: A tuple containing the updated key and value states. """ + # Lazy initialization if self.keys is None: - self.keys = key_states - self.values = value_states - else: - self.keys = torch.cat([self.keys, key_states], dim=-2) - self.values = torch.cat([self.values, value_states], dim=-2) + self.lazy_initialization(key_states) + + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) return self.keys, self.values def get_seq_length(self, cache_position=None) -> int: @@ -170,6 +195,7 @@ def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer the supplied tensors. """ layer = cls() + layer.dtype, layer.device = keys.dtype, keys.device layer.keys = keys layer.values = values return layer @@ -186,61 +212,49 @@ class StaticLayer(CacheLayerMixin): is_compileable = True is_sliding = False - def __init__( - self, - max_cache_len: int, - batch_size: int, - num_heads: int, - head_dim: int, - dtype: torch.dtype = torch.float32, - device: str = "cpu", - sliding_window: Optional[int] = None, - ): + def __init__(self, max_cache_len: int): """ Args: max_cache_len (`int`): Maximum number of tokens that can be stored, used for tensor preallocation. - batch_size (`int`): - Maximum batch size the cache is pre-allocated for. - num_heads (`int`): - Number of attention heads. - head_dim (`int`): - Per-head hidden dimension. - dtype (`torch.dtype`, defaults to `torch.float32`): - Data type of the cache tensors. - device (`str` or `torch.device`, defaults to `"cpu"`): - Device on which the cache tensors will be materialised. - - Notes: - Static layers allocate their full backing tensors up-front and mutate them - in-place. See the documentation of `Cache` for shared helper methods that - operate uniformly across all layer types. """ + super().__init__() self.max_cache_len = max_cache_len - self.max_batch_size = batch_size - self.num_heads = num_heads - self.head_dim = head_dim - self.dtype = dtype - self.device = device + + def lazy_initialization(self, key_states: torch.Tensor): + """ + Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device, + num_heads in case of TP etc...) 
at runtime directly, which is extremely practical as it avoids moving + devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well). + + If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this + function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we + internally don't compile the prefill, this is guaranteed to have been called already when compiling. + If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache, + it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs, + i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should + not be compiled anyway for performances! + """ + self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape + self.dtype, self.device = key_states.dtype, key_states.device self.keys = torch.zeros( - (batch_size, num_heads, self.max_cache_len, head_dim), - dtype=dtype, - device=device, + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, ) self.values = torch.zeros( - (batch_size, num_heads, self.max_cache_len, head_dim), - dtype=dtype, - device=device, + (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), + dtype=self.dtype, + device=self.device, ) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(self.keys) - torch._dynamo.mark_static_address(self.values) - - def get_max_cache_shape(self) -> int: - """Return the maximum cache shape of the cache""" - return self.max_cache_len + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph + # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case. + # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile + # prefill explicitly, but this should be avoided!) + if not is_torchdynamo_compiling(): + torch._dynamo.mark_static_address(self.keys) + torch._dynamo.mark_static_address(self.values) def update( self, @@ -259,34 +273,31 @@ def update( Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - key_states = key_states.to(self.keys.dtype) - value_states = value_states.to(self.values.dtype) - - # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect - # the device_map. However, even if it is the case, this will only run once, because then the new states received - # will always have the same device - if self.device != key_states.device: - self.device = key_states.device - self.keys = self.keys.to(self.device) - self.values = self.values.to(self.device) - - if cache_position is None: - # Prefill phase where seq_len potentially equals max_cache_len. Directly copy. - self.keys.copy_(key_states) - self.values.copy_(value_states) - else: - # Generation phase. Update specific positions. - # Use index_copy_ for in-place update (compile-friendly). 
- try: - self.keys.index_copy_(2, cache_position, key_states) - self.values.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # Fallback for devices like MPS where index_copy_ might not be supported. - self.keys[:, :, cache_position] = key_states - self.values[:, :, cache_position] = value_states + # Lazy initialization + if self.keys is None: + self.lazy_initialization(key_states) + + # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention, + # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len) + cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None + cache_position = ( + cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device) + ) + + # Update the cache + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # Fallback for devices like MPS where index_copy_ might not be supported. + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states return self.keys, self.values + def get_max_cache_shape(self) -> int: + """Return the maximum cache shape of the cache""" + return self.max_cache_len + def get_seq_length(self, cache_position=None) -> int: """Returns the sequence length of the cached states.""" if cache_position is not None: @@ -319,15 +330,17 @@ class SlidingWindowLayer(StaticLayer): is_sliding = True - def __init__(self, sliding_window, *args, **kwargs): + def __init__(self, max_cache_len: int, sliding_window: int): """ Args: + max_cache_len (`int`): + Maximum number of tokens that can be stored, used for tensor preallocation. sliding_window (`int`): - Effective window size: number of tokens that are kept on each update call. + The size of the sliding window. """ - max_cache_len = kwargs.pop("max_cache_len", None) - max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window - super().__init__(*args, max_cache_len=max_cache_len, *args, **kwargs) + effective_max_cache_len = min(sliding_window, max_cache_len) + super().__init__(max_cache_len=effective_max_cache_len) + self.cumulative_length = 0 def update( self, @@ -346,54 +359,46 @@ def update( Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states. """ - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - if cache_position is None: - raise ValueError("`cache_position` must be provided for SlidingWindowLayer.") + # Lazy initialization + if self.keys is None: + self.lazy_initialization(key_states) - # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect - # the device_map. However, even if it is the case, this will only run once, because then the new states received - # will always have the same device - if self.device != key_states.device: - self.device = key_states.device - self.keys = self.keys.to(self.device) - self.values = self.values.to(self.device) + cache_position = cache_kwargs.get("cache_position") - key_states = key_states.to(self.keys.dtype) - value_states = value_states.to(self.values.dtype) + is_full = self.cumulative_length >= self.max_cache_len + # Update it now that we saved the value above + self.cumulative_length += key_states.shape[-2] # Handle prefill phase when prompt length > sliding_window_size. 
# Note that we store cropped key/value states in the cache but return the full key/value states. if cache_position.shape[0] > self.max_cache_len: - new_k = key_states[:, :, -self.max_cache_len :, :] - new_v = value_states[:, :, -self.max_cache_len :, :] - self.keys.copy_(new_k) - self.values.copy_(new_v) + self.keys.copy_(key_states[:, :, -self.max_cache_len :, :]) + self.values.copy_(value_states[:, :, -self.max_cache_len :, :]) + # Return the full states here return key_states, value_states - # Sliding window logic for generation phase or prefill < window - slicing = torch.arange(self.max_cache_len, device=self.device) - current_seq_len = cache_position[-1] + 1 # Use last position to determine current length - to_shift = current_seq_len > self.max_cache_len - indices = (slicing + to_shift.sum()) % self.max_cache_len - - k_out_shifted = self.keys[:, :, indices] - v_out_shifted = self.values[:, :, indices] - - # Clamp cache_position to determine the *target index* within the shifted cache view - update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1) + # Here we only assume decoding stage, i.e. 1 token at a time + if is_full: + # Roll all values to the left by 1 position + new_keys = self.keys.roll(-1, dims=-2) + new_values = self.values.roll(-1, dims=-2) + # Overwrite the last position with new states + # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855) + index = torch.tensor([-1], dtype=int, device=self.device) + new_keys[:, :, index] = key_states + new_values[:, :, index] = value_states + + # Copy back into `self` (do not just assign again) in order to keep the static dynamo address + self.keys.copy_(new_keys) + self.values.copy_(new_values) + else: + try: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + except NotImplementedError: + self.keys[:, :, cache_position] = key_states + self.values[:, :, cache_position] = value_states - try: - k_out_updated = k_out_shifted.index_copy(2, update_position, key_states) - v_out_updated = v_out_shifted.index_copy(2, update_position, value_states) - except NotImplementedError: - # Fallback for MPS: clone and modify the clone - k_out_updated = k_out_shifted.clone() - v_out_updated = v_out_shifted.clone() - k_out_updated[:, :, update_position] = key_states - v_out_updated[:, :, update_position] = value_states - - self.keys.copy_(k_out_updated) - self.values.copy_(v_out_updated) return self.keys, self.values def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: @@ -406,6 +411,14 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: kv_length = max(query_length, self.max_cache_len) return kv_length, kv_offset + def reset(self) -> None: + super().reset() + self.cumulative_length = 0 + + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + class ChunkedSlidingLayer(SlidingWindowLayer): """ @@ -414,31 +427,22 @@ class ChunkedSlidingLayer(SlidingWindowLayer): See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers. 
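+
+    A minimal construction sketch (the sizes below are illustrative only); the backing length is clamped to
+    `min(sliding_window, max_cache_len)` by the parent `SlidingWindowLayer` constructor:
+
+    ```python
+    layer = ChunkedSlidingLayer(max_cache_len=4096, sliding_window=1024)
+    layer.get_max_cache_shape()  # 1024
+    ```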
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.cumulative_length = 0 - def update( self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None - if cache_position is None: - raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.") - - # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect - # the device_map. However, even if it is the case, this will only run once, because then the new states received - # will always have the same device - if self.device != key_states.device: - self.device = key_states.device - self.keys = self.keys.to(self.device) - self.values = self.values.to(self.device) + # Lazy initialization + if self.keys is None: + self.lazy_initialization(key_states) + + cache_position = cache_kwargs.get("cache_position") cumulative_length = self.cumulative_length - self.cumulative_length += key_states.shape[-2] is_full = cumulative_length >= self.max_cache_len + # Update it now that we saved the value above + self.cumulative_length += key_states.shape[-2] if is_full: full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) @@ -451,6 +455,7 @@ def update( self.values.copy_(full_value_states) return self.keys, self.values elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + # Fast prefill path, no need to cat() in this case, as the cache is currently empty if cumulative_length == 0: full_key_states = key_states full_value_states = value_states @@ -468,12 +473,10 @@ def update( self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) + # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context + # which is outside the window return full_key_states, full_value_states - def reset(self) -> None: - super().reset() - self.cumulative_length = 0 - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: query_length = cache_position.shape[0] first_cache_position = cache_position[0] @@ -493,392 +496,111 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: return kv_length, kv_offset -class CacheProcessor: - """ - Base class for cache processors. It defines a pre-update and post-update methods that are called before and after the cache update. - This class should be subclassed. - """ - - def __init__(self, cache: "Cache", **kwargs) -> None: - """ - Initialize the processor and perform compatibility checks with the cache. - - Args: - cache (`Cache`): The cache instance this processor will be applied to. - **kwargs: Additional arguments that may be needed for initialization. - """ - raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.") - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called before the cache update. Can modify the key/value states. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The new key states to cache. - value_states (`torch.Tensor`): The new value states to cache. 
- layer_idx (`int`): The index of the layer to cache the states for. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - The modified key and value states. - """ - return key_states, value_states - - def post_update( - self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Function called after the cache update. Can process the cached data. - - Args: - cache (`Cache`): The cache instance. - key_states (`torch.Tensor`): The key states that were cached. - value_states (`torch.Tensor`): The value states that were cached. - layer_idx (`int`): The index of the layer that was updated. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - The final key and value states to return to the model. - """ - return key_tensors, value_tensors - - -class OffloadedCacheProcessor(CacheProcessor): +class QuantizedLayer(DynamicLayer): """ - A cache processor that offloads cache tensors to conserve accelerator memory. - - This processor manages moving cache tensors between accelerator and CPU memory, - using asynchronous prefetching to minimize performance impact. Works with both - dynamic and static layers. - """ - - def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "cpu", **kwargs): - """Initialize the offload processor and check device compatibility.""" - self.offload_device = torch.device(offload_device) - self.original_device = [] - self.prefetch_stream = None - self.beam_idx = None - - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCacheProcessor can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) - - self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers) - if self.is_static: - for i, layer in enumerate(cache.layers): - device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device - layer.keys = layer.keys.to(device) - layer.values = layer.values.to(device) - self.original_device.append(cache.layer_init_kwargs["device"]) - if len(cache) != cache.num_hidden_layers: - raise ValueError("If static layers are used, all cache layers must be initialized") - - self.prefetch_stream = ( - torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream() - ) - - def pre_update( - self, - cache: "Cache", - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Handles prefetching and eviction before cache update.""" - # Update the cache - if len(cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") - elif len(cache) == layer_idx: - self.original_device.append(key_states.device) - self._evict_previous_layer(cache, layer_idx) - else: - # Wait for the previous layer to be evicted (on default stream) - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self._evict_previous_layer(cache, layer_idx) - self._ensure_layer_on_device(cache, layer_idx) - - # Prefetch the next layer - self._prefetch_layer(cache, (layer_idx + 1) % len(cache)) - return key_states, value_states - - def _prefetch_layer(self, cache: "Cache", layer_idx: int): - """Starts prefetching the next layer cache.""" - if layer_idx < len(cache): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True) - cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True) - - def _evict_previous_layer(self, cache: "Cache", layer_idx: int): - """Moves the previous layer cache to the CPU.""" - if len(cache) >= 2: # Layer 0 stays on device to be on-device after all layers are created - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(cache) - cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to( - self.offload_device, non_blocking=True - ) - cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to( - self.offload_device, non_blocking=True - ) - - def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int): - """Ensures the current layer is on the original device.""" - if layer_idx < len(cache): - # Wait for the previous prefetch to be done - self.prefetch_stream.synchronize() - - # Handle delayed beam search operations - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(self.original_device[layer_idx]) - cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx) - cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx) - - -class QuantizedCacheProcessor(CacheProcessor): - """ - A cache processor that applies quantization to cache tensors to reduce memory usage. - - This processor quantizes cache tensors after they are stored, maintaining a residual - length in original precision and quantizing older tokens. + A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by + applying quantization. + + The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` + is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original + precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size` + for both Keys and Values, in contrast to what was described in the paper. 
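+
+    A minimal usage sketch, assuming a standard causal-LM checkpoint (the model name and hyper-parameters below are
+    placeholders, not recommendations): concrete subclasses such as `QuantoQuantizedLayer` or `HQQQuantizedLayer`
+    can be assembled into a `Cache`, one layer per decoder layer.
+
+    ```python
+    from transformers import AutoModelForCausalLM, AutoTokenizer, Cache, QuantoQuantizedLayer
+
+    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+    inputs = tokenizer("Hello", return_tensors="pt")
+
+    # One quantized layer per model layer; the most recent tokens (up to `residual_length`) stay in original precision
+    layers = [QuantoQuantizedLayer(nbits=4, residual_length=128) for _ in range(model.config.num_hidden_layers)]
+    outputs = model.generate(**inputs, past_key_values=Cache(layers=layers))
+    ```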
""" def __init__( self, - cache: "Cache", - backend: str = "quanto", nbits: int = 4, axis_key: int = 0, axis_value: int = 0, q_group_size: int = 64, residual_length: int = 128, - compute_dtype: torch.dtype = torch.float16, - device: str = "cpu", ): - """ - Parameters: - backend (`str`, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`int`, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`int`, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`int`, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. - """ - self.backend = backend + super().__init__(self) self.nbits = nbits self.axis_key = axis_key self.axis_value = axis_value self.q_group_size = q_group_size self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device - self._quantized_keys: list[torch.Tensor] = [] - self._quantized_values: list[torch.Tensor] = [] - - self.validate() - self.erased_length = 0 - - # Only compatible with DynamicCache - if not isinstance(cache.layers[0], DynamicLayer): - raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache") - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. 
`{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( - incorrect_arg_msg.format( - key="residual_length", - correct_value="a positive integer", - found_value=self.residual_length, - ), - ) - - if self.axis_key not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_key", - correct_value="`1` or `0`, `-1`", - found_value=self.axis_key, - ), - ) - - if self.axis_value not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_value", - correct_value="`1` or `0` or `-1`", - found_value=self.axis_value, - ), - ) + self.cumulative_length = 0 - def post_update( + def update( self, - cache: "Cache", - key_tensors: torch.Tensor, - value_tensors: torch.Tensor, - layer_idx: int, + key_states: torch.Tensor, + value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply quantization after cache update.""" - - if len(cache) < layer_idx: - raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.") - - # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer - # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0) - # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full. - if self._is_quantized_length_zero(layer_idx): - self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key)) - self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value)) - - # Clear the residual cache - self.erased_length = key_tensors.shape[-2] - cache.layers[layer_idx].keys = torch.zeros( - 0, - dtype=key_tensors.dtype, - device=key_tensors.device, - ) - cache.layers[layer_idx].values = torch.zeros( - 0, - dtype=value_tensors.dtype, - device=value_tensors.device, - ) - # On prefill, we return the original prompt - keys_to_return, values_to_return = key_tensors, value_tensors + """ + Updates the cache with the new `key_states` and `value_states`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`dict[str, Any]`, *optional*): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`. + + Return: + A tuple containing the updated key and value states. 
+ """ + self.cumulative_length += key_states.shape[-2] + + # Lazy initialization + if self.keys is None: + self.lazy_initialization(key_states) + self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key) + self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value) + return key_states, value_states + dequant_keys = self._dequantize(self._quantized_keys) + dequant_values = self._dequantize(self._quantized_values) + keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2) + values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2) + if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length: + self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) + self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value) + self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device) + self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device) else: - # Prepend the previously quantized cache - dequant_key = self._dequantize(self._quantized_keys[layer_idx]) - dequant_value = self._dequantize(self._quantized_values[layer_idx]) - keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2) - values_to_return = torch.cat([dequant_value, value_tensors], dim=-2) - if key_tensors.shape[-2] >= self.residual_length: - # Quantize and store - self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) - self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value) - - # Clear the residual cache - self.erased_length += key_tensors.shape[-2] - cache.layers[layer_idx].keys = torch.zeros( - 0, - dtype=key_tensors.dtype, - device=key_tensors.device, - ) - cache.layers[layer_idx].values = torch.zeros( - 0, - dtype=value_tensors.dtype, - device=value_tensors.device, - ) + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) return keys_to_return, values_to_return - def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: - """Quantize a tensor - to be implemented by specific quantization backends.""" - raise NotImplementedError("Quantization backend must implement _quantize method") - - def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor: - """Dequantize a tensor - to be implemented by specific quantization backends.""" - raise NotImplementedError("Quantization backend must implement _dequantize method") + def get_seq_length(self, cache_position=None) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length - def _is_quantized_length_zero(self, layer_idx: int) -> bool: - """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened.""" - return layer_idx >= len(self._quantized_keys) + @abstractmethod + def _quantize(self, tensor, axis): ... + @abstractmethod + def _dequantize(self, q_tensor): ... -class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor): - """ - Quantized cache processor that uses `quanto` as a backend to perform quantization. - Current implementation supports `int2` and `int4` dtypes only. 
- """ +class QuantoQuantizedLayer(QuantizedLayer): def __init__( self, - cache: "Cache", - backend: str = "quanto", nbits: int = 4, axis_key: int = 0, axis_value: int = 0, q_group_size: int = 64, residual_length: int = 128, - compute_dtype: torch.dtype = torch.float16, - device: str = "cpu", - ) -> None: - """Initialize the quanto quantization processor.""" + ): super().__init__( - cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device + nbits=nbits, + axis_key=axis_key, + axis_value=axis_value, + q_group_size=q_group_size, + residual_length=residual_length, ) - if backend != "quanto": - raise ValueError(f"QuantoQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}." - ) - from optimum.quanto import MaxOptimizer, qint2, qint4 + if not _is_quanto_greater_than_0_2_5: + raise ImportError( + "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. " + "Detected version {optimum_quanto_version}." + ) if self.nbits not in [2, 4]: raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") @@ -892,47 +614,36 @@ def __init__( ) self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = MaxOptimizer() - - def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor: - """Quantize tensor using quanto backend.""" - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight + self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor + def _quantize(self, tensor, axis): + scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) + qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) + return qtensor - def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor: - """Dequantize tensor using quanto backend.""" + def _dequantize(self, qtensor): return qtensor.dequantize() -class HQQQuantizedCacheProcessor(QuantizedCacheProcessor): - """ - Quantized cache processor that uses `HQQ` as a backend to perform quantization. - Current implementation supports `int2`, `int4`, `int8` dtypes. 
- """ - +class HQQQuantizedLayer(QuantizedLayer): def __init__( self, - cache: "Cache", - backend: str = "quanto", nbits: int = 4, axis_key: int = 0, axis_value: int = 0, q_group_size: int = 64, residual_length: int = 128, - compute_dtype: torch.dtype = torch.float16, - device: str = "cpu", - ) -> None: - """Initialize the HQQ quantization processor.""" + ): super().__init__( - cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device + nbits=nbits, + axis_key=axis_key, + axis_value=axis_value, + q_group_size=q_group_size, + residual_length=residual_length, ) - if backend != "quanto": - raise ValueError(f"HQQQuantizedCacheProcessor only supports `quanto` backend, but got {backend}") + if not is_hqq_available(): + raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`") if self.nbits not in [1, 2, 3, 4, 8]: raise ValueError( @@ -947,58 +658,32 @@ def __init__( self.quantizer = HQQQuantizer - def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]: - """Quantize tensor using HQQ backend.""" + def _quantize(self, tensor, axis): qtensor, meta = self.quantizer.quantize( tensor, axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, + device=self.keys.device, + compute_dtype=self.keys.dtype, nbits=self.nbits, group_size=self.q_group_size, ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype + meta["compute_dtype"] = self.keys.dtype + self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device) # Move to device and cast to dtype meta["scale"] = meta["scale"].to(qtensor.device) meta["zero"] = meta["zero"].to(qtensor.device) return qtensor, meta - def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor: - """Dequantize tensor using HQQ backend.""" - quant_tensor, meta = qtensor_and_meta + def _dequantize(self, qtensor): + quant_tensor, meta = qtensor tensor = self.quantizer.dequantize(quant_tensor, meta) return tensor -def apply_processors( - fn: Callable[..., tuple[torch.Tensor, torch.Tensor]], -) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: - @functools.wraps(fn) - def _wrapped_update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Wrapper around the update method to apply cache processors. - """ - if self.cache_processor is not None: - key_states, value_states = self.cache_processor.pre_update( - self, key_states, value_states, layer_idx, cache_kwargs - ) - - key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs) - - if self.cache_processor is not None: - key_tensors, value_tensors = self.cache_processor.post_update( - self, key_tensors, value_tensors, layer_idx, cache_kwargs - ) - - return key_tensors, value_tensors - - return _wrapped_update +LAYER_CLASS_MAP: dict[str, type[CacheLayerMixin]] = { + "full_attention": StaticLayer, + "sliding_attention": SlidingWindowLayer, + "chunked_attention": ChunkedSlidingLayer, +} class KeyValuesWrapper: @@ -1035,144 +720,80 @@ def __bool__(self): class Cache: """ - Base container for per-layer key/value caches. - - A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer. 
- Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache` - simply pre-select which `CacheLayerMixin` class to use and may attach a - `CacheProcessor` (off-loading, quantization). - - Example - ------- - ```python - from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache - - model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - inputs = tok("Hello", return_tensors="pt") - - cache = DynamicCache() - outputs = model(**inputs, past_key_values=cache, use_cache=True) - ``` + A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for + the Cache of each layer. Parameters: - layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`): - A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is - provided, then it is used for all layers. - config (`PretrainedConfig`, *optional*): - Model configuration used to infer number of layers, head sizes, default - device/dtype, etc. - cache_processor (`CacheProcessor` or `str`, *optional*): - Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized") - or a CacheProcessor class. - max_batch_size (`int`, *optional*): Maximum batch size for static caches. - max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are - clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`. - device (`torch.device`, *optional*): Device for cache tensors. - dtype (`torch.dtype`, *optional*): Data type for cache tensors. - layer_device_map (`dict[int, Union[str, torch.device]]`, *optional*): Per-layer device mapping. - tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads. - - Additional keyword arguments are forwarded to the chosen layers constructor(s) and CacheProcessors. See the - documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details. + layers (`Optional`, *optional*): + A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will + be used. + layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*): + Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer, + and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current + list of layers. + offloading (`bool`, *optional*, defaults to `False`): + Whether to perform offloading of the layers to `cpu`, to save GPU memory. + offload_only_non_sliding (`bool`, *optional*, defaults to `True`): + If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because + usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster). 
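+
+    Example (a minimal sketch; the layer count and cache length below are arbitrary placeholders):
+
+    ```python
+    from transformers import Cache, DynamicLayer, StaticLayer
+
+    # Grow-as-you-go cache: a new DynamicLayer is appended on the first `update` call for each layer index
+    cache = Cache(layer_class_to_replicate=DynamicLayer)
+
+    # Pre-defined cache: one StaticLayer per decoder layer; `offloading=True` moves inactive layers to CPU
+    # (offloading assumes a GPU/XPU is available for the prefetch stream)
+    cache = Cache(layers=[StaticLayer(max_cache_len=1024) for _ in range(32)], offloading=True)
+    ```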
""" def __init__( self, - layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]], - config: Optional[PretrainedConfig] = None, - cache_processor: Optional[Union[str, type[CacheProcessor]]] = None, - max_batch_size: Optional[int] = None, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - layer_device_map: Optional[dict[int, torch.device]] = None, - tp_size: Optional[int] = None, - **kwargs, + layers: Optional[list[CacheLayerMixin]] = None, + layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None, + offloading: bool = False, + offload_only_non_sliding: bool = True, ): - self.layers: list[CacheLayerMixin] = [] - self.layer_classes = layer_classes - - processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor - kwargs.update( - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - tp_size=tp_size, - ) - processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs) - - self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs) - self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) - - self.append_new_layers(self.num_hidden_layers - 1) - self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None - - def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self.layers): - return self.layers[layer_idx].keys, self.layers[layer_idx].values - else: - raise KeyError( - f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + if layers is not None and layer_class_to_replicate is not None: + raise ValueError( + "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a " + "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to " + "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache." + ) + if layers is None and layer_class_to_replicate is None: + raise ValueError( + "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache." ) + self.layers = layers if layers is not None else [] + self.layer_class_to_replicate = layer_class_to_replicate + self.offloading = offloading + if self.offloading: + self.only_non_sliding = offload_only_non_sliding + self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream() - def __iter__(self): - """ - Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + def __repr__(self): + return f"{self.__class__.__name__}(layers={self.layers})" - def __len__(self): + def prefetch(self, layer_idx: int, only_non_sliding: bool = True): """ - Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds - to the number of layers in the model. + Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers + which are non-sliding. 
If the `layer_idx` is outside the range, this will circle back to the first layers. + Note that we use a non-default stream for this, to avoid blocking. """ - # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ - if getattr(self, "layers", None) is None: - if getattr(self, "key_cache", None) is not None: - return len(self.key_cache) - return 0 - # Empty dynamic caches initialize an empty layer to be ready for first update - dynamic_empty = ( - getattr(self, "layers", None) is not None - and len(self.layers) == 1 - and isinstance(self.layers[0], DynamicLayer) - and self.layers[0].keys is None - ) - return len(self.layers) if not dynamic_empty else 0 + if only_non_sliding: + # Try to find next non-sliding, starting at `layer_idx` + try: + layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False) + # In this case, we need to circle back to the begining + except ValueError: + layer_idx = self.is_sliding.index(False) + else: + layer_idx = layer_idx if layer_idx < len(self.layers) else 0 - def __repr__(self): - return f"{self.__class__.__name__}(layers={self.layers})" + # Prefetch + with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream): + self.layers[layer_idx].prefetch() - def append_new_layers(self, layer_idx: int) -> None: + def offload(self, layer_idx: int, only_non_sliding: bool = True): """ - Appends layers to the cache until the layer `layer_idx` is reached. - Used for preallocation in static caches and on the fly in dynamic caches. - - Args: - layer_idx (`int`): - The index of the layer to append. + Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a + non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier + computation in the layer's `update` methods are finished. """ - while len(self.layers) <= layer_idx: - kwargs = self.layer_init_kwargs.copy() - if self.layer_init_kwargs.get("layer_device_map", None) is not None: - kwargs["device"] = kwargs.pop("layer_device_map")[len(self.layers)] + if not (only_non_sliding and self.is_sliding[layer_idx]): + self.layers[layer_idx].offload() - new_layer_class = ( - self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes - ) - new_layer = new_layer_class(**kwargs) - self.layers.append(new_layer) - - @apply_processors def update( self, key_states: torch.Tensor, @@ -1197,48 +818,62 @@ def update( Return: A tuple containing the updated key and value states. 
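+
+        A sketch of how an attention layer typically calls this (variable names are illustrative):
+
+        ```python
+        key_states, value_states = past_key_values.update(
+            key_states, value_states, layer_idx, {"cache_position": cache_position}
+        )
+        ```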
""" - self.append_new_layers(layer_idx) - return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + # In this case, the `layers` were not provided, and we must append as much as `layer_idx` + if self.layer_class_to_replicate is not None: + while len(self.layers) <= layer_idx: + self.layers.append(self.layer_class_to_replicate()) + + if self.offloading: + # Wait for the stream to finish if needed, and start prefetching the next layer + torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream) + self.prefetch(layer_idx + 1, self.only_non_sliding) + + keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + if self.offloading: + self.offload(layer_idx, self.only_non_sliding) + + return keys, values + + def early_initialization( + self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device + ): + """ + Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call). + This is useful for our `export` recipes, as `export` needs everything in advance. + """ + # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use + # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only + # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical + fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device) + # Init all layers + for layer in self.layers: + layer.lazy_initialization(fake_keys_tensor) def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int: - """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + """Returns the sequence length of the cache for the given layer.""" if layer_idx >= len(self.layers): return 0 - # Hack since QuantizedCache messes with keys shape as it becomes the residual cache - if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): - return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position) return self.layers[layer_idx].get_seq_length(cache_position) def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for the given layer at `layer_idx`. - The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), - for each layer. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. """ - kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position) - return kv_length, kv_offset - - @property - def key_cache(self) -> KeyValuesWrapper: - """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" - logger.warning_once( - "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." - ) - return KeyValuesWrapper(self.layers, "keys") - - @property - def value_cache(self) -> KeyValuesWrapper: - """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" - logger.warning_once( - "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." 
- ) - return KeyValuesWrapper(self.layers, "values") - - ### Wrappers for layer operations and properties ### + # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is + # simply the shape of `cache_position` + if layer_idx >= len(self.layers): + return cache_position.shape[0], 0 + return self.layers[layer_idx].get_mask_sizes(cache_position) def get_max_cache_shape(self, layer_idx: int = 0) -> int: """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length.""" + # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1 + # as DynamicLayer does + if layer_idx >= len(self.layers): + return -1 return self.layers[layer_idx].get_max_cache_shape() def reset(self): @@ -1283,6 +918,9 @@ def max_cache_len(self) -> int: @property def is_compileable(self) -> bool: """Return whether the cache is compileable""" + # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True) + if len(self.layers) == 0: + return False return all(layer.is_compileable for layer in self.layers) @property @@ -1290,6 +928,50 @@ def is_sliding(self) -> list[bool]: """Return whether the layers of the cache are sliding window""" return [getattr(layer, "is_sliding", False) for layer in self.layers] + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self.layers): + return self.layers[layer_idx].keys, self.layers[layer_idx].values + else: + raise KeyError( + f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + """ + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.layers[layer_idx].keys, self.layers[layer_idx].values) + + def __len__(self): + """ + This value corresponds to the number of layers in the model. + """ + # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first + # forward through all the layers + return len(self.layers) + + @property + def key_cache(self) -> KeyValuesWrapper: + """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`""" + logger.warning_once( + "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead." + ) + return KeyValuesWrapper(self.layers, "keys") + + @property + def value_cache(self) -> KeyValuesWrapper: + """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`""" + logger.warning_once( + "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead." + ) + return KeyValuesWrapper(self.layers, "values") + class DynamicCache(Cache): """ @@ -1319,17 +1001,18 @@ class DynamicCache(Cache): """ # Specialized constructor for DDP cache data, needed for BC - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): - super().__init__(layer_classes=DynamicLayer, *args, **kwargs) + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None): # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). 
See #36212 # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the # iterable contains the key and value states for a layer gathered across replicas by torch.distributed # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. if ddp_cache_data is not None: + layers = [] for key_states, value_states in ddp_cache_data: - self.layers.append(DynamicLayer.from_tensors(key_states, value_states)) + layers.append(DynamicLayer.from_tensors(key_states, value_states)) + super().__init__(layers=layers) + else: + super().__init__(layer_class_to_replicate=DynamicLayer) def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: """ @@ -1403,22 +1086,16 @@ def _unflatten_dynamic_cache( ) -class OffloadedCache(DynamicCache): +class OffloadedCache(Cache): """ - A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. + A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory. Useful for generating from models with very long context. - In addition to the default accelerator stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. + See `Cache` for details on common methods that are implemented by all cache classes. """ def __init__(self) -> None: - # Create the underlying cache with offload processor - super().__init__(cache_processor=OffloadedCacheProcessor) + super().__init__(layer_class_to_replicate=DynamicLayer, offloading=True) class StaticCache(Cache): @@ -1440,18 +1117,20 @@ class StaticCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = StaticCache(max_cache_len=max_generated_length, config=model.config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation StaticCache() ``` """ - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=StaticLayer, *args, **kwargs) + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs): + layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)] + super().__init__(layers=layers) -class OffloadedStaticCache(StaticCache): +class OffloadedStaticCache(Cache): """ A drop-in replacement for StaticCache that conserves accelerator memory by offloading cache tensors to CPU when not actively being used. 
@@ -1472,40 +1151,22 @@ class OffloadedStaticCache(StaticCache): >>> # Prepare a cache class with offloading >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache( - ... config=model.config, - ... max_batch_size=1, - ... max_cache_len=max_generated_length, - ... device=model.device, - ... dtype=model.dtype - ... ) + >>> past_key_values = OffloadedStaticCache(max_cache_len=max_generated_length, config=model.config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache with offloaded layers OffloadedStaticCache() ``` """ - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs): + layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)] + super().__init__(layers=layers, offloading=True) class SlidingWindowCache(Cache): """ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - See `Cache` for details on common methods that are implemented by all cache classes. 
Example: @@ -1521,22 +1182,24 @@ class SlidingWindowCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = SlidingWindowCache(max_cache_len=max_generated_length, config=model.config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation SlidingWindowCache() ``` """ - def __init__(self, *args, **kwargs): - super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs) + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs): + layers = [SlidingWindowLayer(max_cache_len, config.sliding_window) for _ in range(config.num_hidden_layers)] + super().__init__(layers=layers) class HybridCache(Cache): """ Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window attention and global attention in every other layer (originally implemented for Gemma2). - Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention and ["StaticCache"] + Under the hood, Hybrid Cache leverages ["SlidingWindowLayer"] for sliding window attention and ["StaticLayer"] for global attention. For more information, see the documentation of those layer types. See `Cache` for details on common methods that are implemented by all cache classes. 
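Taken together, the constructor changes in the hunks above converge on one calling pattern: every preallocated cache is now built from just `max_cache_len` and the model `config`, with shapes, dtype and device resolved lazily on the first `update` call. A minimal sketch of that pattern, mirroring the doctest examples in this diff (the checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# Leave room for 10 generated tokens. There are no max_batch_size/device/dtype arguments anymore:
# those are inferred from the first key/value states passed to `update`.
max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = StaticCache(max_cache_len=max_generated_length, config=model.config)

# SlidingWindowCache, HybridCache and the offloaded variants accept the same two arguments,
# reading sliding_window / attention_chunk_size / layer_types from the config as needed.
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```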
@@ -1554,27 +1217,35 @@ class HybridCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = HybridCache(max_cache_len=max_generated_length, config=model.config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation HybridCache() ``` """ - def __init__(self, config: PretrainedConfig, *args, **kwargs): + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs): if hasattr(config, "layer_types"): - layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types] + layers = [] + for layer_type in config.layer_types: + init_kwargs = {"max_cache_len": max_cache_len} + if layer_type == "sliding_attention": + init_kwargs["sliding_window"] = config.sliding_window + elif layer_type == "chunked_attention": + init_kwargs["sliding_window"] = config.attention_chunk_size + layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs)) else: # In this case, fall back to StaticCache - layer_classes = [StaticLayer] * config.num_hidden_layers - super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs) + layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)] + super().__init__(layers=layers) # The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC class HybridChunkedCache(HybridCache): ... -class OffloadedHybridCache(HybridChunkedCache): +class OffloadedHybridCache(Cache): """ A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading cache tensors to CPU when not actively being used. @@ -1585,51 +1256,73 @@ class OffloadedHybridCache(HybridChunkedCache): See `Cache` for details on common methods that are implemented by all cache classes. """ - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs) + # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before) + def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs): + if hasattr(config, "layer_types"): + layers = [] + for layer_type in config.layer_types: + init_kwargs = {"max_cache_len": max_cache_len} + if layer_type == "sliding_attention": + init_kwargs["sliding_window"] = config.sliding_window + elif layer_type == "chunked_attention": + init_kwargs["sliding_window"] = config.attention_chunk_size + layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs)) + else: + # In this case, fall back to StaticCache + layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)] + super().__init__(layers=layers, offloading=True) -class QuantizedCache(DynamicCache): +class QuantizedCache(Cache): """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. 
- - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. + A quantizer cache similar to what is described in the + [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for keys and values + by applying quantization. + The cache has two types of storage, one for original precision and one for the + quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the + length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. + The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was + described in the paper. See `Cache` for details on common methods that are implemented by all cache classes. """ - def __init__(self, backend, **kwargs) -> None: + def __init__( + self, + backend: str, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): if backend == "quanto": - processor = QuantoQuantizedCacheProcessor + layer_class = QuantoQuantizedLayer elif backend == "hqq": - processor = HQQQuantizedCacheProcessor + layer_class = HQQQuantizedLayer else: raise ValueError(f"Unknown quantization backend `{backend}`") - super().__init__(cache_processor=processor, **kwargs) + layers = [ + layer_class(nbits, axis_key, axis_value, q_group_size, residual_length) + for _ in range(config.num_hidden_layers) + ] + super().__init__(layers=layers) class QuantoQuantizedCache(QuantizedCache): """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `quanto` as a backend to perform quantization. 
Current implementation supports `int2` and `int4` dtypes only. + A quantizer cache similar to what is described in the + [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for keys and values + by applying quantization. + The cache has two types of storage, one for original precision and one for the + quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the + length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. + The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was + described in the paper. See `Cache` for details on common methods that are implemented by all cache classes. @@ -1637,7 +1330,7 @@ class QuantoQuantizedCache(QuantizedCache): ```python >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -1645,32 +1338,36 @@ class QuantoQuantizedCache(QuantizedCache): >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> past_key_values = QuantoQuantizedCache(config=model.config, nbits=4) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation QuantoQuantizedCache() ``` """ - def __init__(self, **kwargs) -> None: - DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs) + def __init__( + self, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length) class HQQQuantizedCache(QuantizedCache): """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - - Uses `HQQ` as a backend to perform quantization. 
Current implementation supports `int2`, `int4`, `int8` dtypes. + A quantizer cache similar to what is described in the + [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). + It allows the model to generate longer sequence length without allocating too much memory for keys and values + by applying quantization. + The cache has two types of storage, one for original precision and one for the + quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the + length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. + The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was + described in the paper. See `Cache` for details on common methods that are implemented by all cache classes. @@ -1678,7 +1375,7 @@ class HQQQuantizedCache(QuantizedCache): ```python >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -1686,17 +1383,23 @@ class HQQQuantizedCache(QuantizedCache): >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> past_key_values = HQQQuantizedCache(config=model.config, nbits=4, axis_key=1, axis_value=1) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation HQQQuantizedCache() ``` """ - def __init__(self, backend="HQQ", **kwargs) -> None: - assert backend == "HQQ" - DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs) + def __init__( + self, + config: PretrainedConfig, + nbits: int = 4, + axis_key: int = 0, + axis_value: int = 0, + q_group_size: int = 64, + residual_length: int = 128, + ): + super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length) class EncoderDecoderCache(Cache): @@ -1724,16 +1427,15 @@ class EncoderDecoderCache(Cache): >>> outputs.past_key_values # access cache filled with key/values from generation EncoderDecoderCache() ``` - """ - # Override @property from Cache - is_compileable = None + # Override @property from Cache -> this will be set in __init__ on the instances + is_compileable = False def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__(layer_classes=DynamicLayer) self.self_attention_cache = self_attention_cache self.cross_attention_cache = cross_attention_cache + # Override @property from Cache self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) self.is_updated = {} @@ -1884,115 +1586,6 @@ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) -def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]: - """ - Parse processor arguments from kwargs based on the processor class init signature. 
- - Args: - processor_class: The processor class to inspect, or None - kwargs: Dictionary of keyword arguments - - Returns: - tuple: (processor_kwargs, remaining_kwargs) - """ - try: - params = list(inspect.signature(processor_class.__init__).parameters)[2:] - except Exception: - return {}, kwargs - - processor_kwargs = {k: kwargs[k] for k in params if k in kwargs} - remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs} - return processor_kwargs, remaining_kwargs - - -def parse_layer_args_from_model_config( - config: Optional[PretrainedConfig], - batch_size: Optional[int] = None, - max_cache_len: Optional[int] = None, - device: Union[torch.device, str, None] = None, - dtype: Optional[torch.dtype] = None, - layer_device_map: Optional[dict[int, torch.device]] = None, - tp_size: Optional[int] = None, - max_batch_size: Optional[int] = None, -) -> dict: - """ - Parse layer arguments from model configuration for cache initialization. - - Args: - config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info. - batch_size (`Optional[int]`): Batch size for cache initialization. - max_cache_len (`Optional[int]`): Maximum sequence length for cache. - device (`Union[torch.device, str, None]`): Device for cache tensors. - dtype (`Optional[torch.dtype]`): Data type for cache tensors. - layer_device_map: Per-layer device mapping. - tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads. - max_batch_size (`Optional[int]`): Maximum batch size for cache initialization. - - Returns: - `dict`: Dictionary containing parsed layer arguments for cache initialization. - """ - # No model config -> must be a dynamic cache, return bare dict - if config is None: - return {} - # Build the args dict for hybrid, sliding or static - else: - # Hybrid/Sliding caches require a config that supports sliding_window (max_cache_len already used) - if ( - getattr(config, "layer_types", None) is not None - and "sliding_attention" in config.layer_types - and "full_attention" in config.layer_types - ): - if getattr(config, "sliding_window", None) is None: - raise ValueError( - "Setting up a hybrid or sliding window KVCache requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) - max_cache_len = max_cache_len or config.max_position_embeddings - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads: - head_dim = ( - config.head_dim - if getattr(config, "head_dim", None) is not None - else config.hidden_size // config.num_attention_heads - ) - num_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - if tp_size is not None and tp_size > 1: - if num_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. 
- num_heads //= tp_size - layer_args = { - "batch_size": max_batch_size if max_batch_size is not None else batch_size, - "max_cache_len": max_cache_len, - "device": torch.device(device) if device is not None else None, - "dtype": dtype, - "layer_device_map": layer_device_map, - "head_dim": head_dim, - "num_heads": num_heads, - "sliding_window": getattr(config, "sliding_window", None), - } - return {k: v for k, v in layer_args.items() if v is not None} - - -LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = { - "full_attention": StaticLayer, - "sliding_attention": SlidingWindowLayer, - "chunked_attention": ChunkedSlidingLayer, -} -PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = { - "offloaded": OffloadedCacheProcessor, - "quanto_quantized": QuantizedCacheProcessor, - "hqq_quantized": HQQQuantizedCacheProcessor, -} - - ### Deprecated classes @@ -2234,91 +1827,6 @@ def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): self.max_cache_len = max_cache_len self.device = device - def initialise_cache_layer(self, layer_idx, key_states): - """Overridden to use the correct device if offloaded layer (and pin memory).""" - if len(self.key_cache) > layer_idx: - return - - num_key_value_heads = key_states.shape[1] - device = key_states.device if self.is_sliding[layer_idx] else self.offload_device - pin_memory = not self.is_sliding[layer_idx] - global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - # Make sure to initialize the on-device layer if it does not already exist - if self.device_key_cache is None and not self.is_sliding[layer_idx]: - self.device_key_cache = [] - self.device_value_cache = [] - # We need 2 layers to avoid race conditions when prefetching the next one - for _ in range(2): - device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.device_key_cache.append(device_layer_key_cache) - self.device_value_cache.append(device_layer_value_cache) - - def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - # Wait for prefetch stream if needed - if self._prefetch_stream is not None: - torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream) - - # Get correct on-device layer - k_out = self.device_key_cache[self.active_device_layer] - v_out = self.device_value_cache[self.active_device_layer] - - # Let's prefetch the next layer as soon as possible - self._prefetch_next_layer(layer_idx) - - # Copy to on-device layer - k_out[:, :, cache_position] = 
key_states - v_out[:, :, cache_position] = value_states - - # Copy to offloaded device - self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device) - self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device) - - return k_out, v_out - - def _prefetch_next_layer(self, layer_idx: int) -> None: - """Based on current layer_idx, prefetch next full layer to the device.""" - - # Switch the active layer - self.active_device_layer = 0 if self.active_device_layer == 1 else 1 - - # Find the next non-sliding layer - try: - next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False) - # In this case, we are at the last layer, and we go back to prefect the first one - except ValueError: - next_layer = self.is_sliding.index(False) - - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(next_layer) - else: - self._prefetch_layer_in_context(next_layer) - - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - if len(self.key_cache) > layer_idx: - self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) - self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) - # The layer was not yet initialized - else: - self.device_key_cache[self.active_device_layer].fill_(0.0) - self.device_value_cache[self.active_device_layer].fill_(0.0) - # TODO (manuel, joao): remove this class, it is here only for backwards compatibility # PEP 562: Lazy loading for deprecated location of MambaCache diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index af97522a92cc..7b40178023da 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1798,7 +1798,8 @@ def _get_initial_cache_position(self, seq_length, device, model_kwargs): if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] past_length = 0 - if not isinstance(cache, Cache): + # Support for BC tuple cache format + if isinstance(cache, tuple): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() @@ -1808,87 +1809,7 @@ def _get_initial_cache_position(self, seq_length, device, model_kwargs): model_kwargs["cache_position"] = cache_position return model_kwargs - def _get_layer_device_map_for_cache_init(self) -> Optional[dict[int, Union[str, int]]]: - """ - Returns the device map for each decoder layer, to allocate the cache on the right device. - Inspired from `dispatch_model` in accelerate. 
- """ - execution_device_map = None - - if hasattr(self, "hf_device_map"): - if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}: - main_device = "cpu" - else: - main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] - execution_device_map = { - name: main_device if device in ["cpu", "disk"] else device - for name, device in self.hf_device_map.items() - } - - # No `execution_device_map` -> rely on `self.device` to allocate the cache - if execution_device_map is None: - return None - - # Single device for all layers - num_hidden_layers = self.config.get_text_config().num_hidden_layers - if len(execution_device_map) == 1 and "" in execution_device_map: - return dict.fromkeys(range(num_hidden_layers), execution_device_map[""]) - - # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device. - layer_device_map = {} - # Case 1: The model has a `get_decoder` method, we can use it to find the decoder name. - if hasattr(self, "get_decoder"): - decoder_name = None - for name, module in self.named_modules(): - if module is self.get_decoder(): - decoder_name = name - break - if decoder_name is None: - raise RuntimeError( - "`model.get_decoder()` is not returning a named module of the model. This is unexpected, please " - "open an issue on GitHub." - ) - - decoder_mapped_modules = [ - module_name for module_name in execution_device_map if decoder_name in module_name - ] - # The decoder name may be present in `execution_device_map` in two forms: - # a) each layer has a device mapping - if len(decoder_mapped_modules) >= num_hidden_layers: - for idx in range(num_hidden_layers): - for module_name in decoder_mapped_modules: - if f".{idx}." in f"{module_name}.": - layer_device_map[idx] = execution_device_map[module_name] - break - - # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map, - # then the mapping is done in a parent module - else: - while True: - if decoder_name in execution_device_map: - layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name]) - break - elif "." in decoder_name: - decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module - else: - raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map") - - # Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index) - else: - for layer in execution_device_map: - for idx in range(num_hidden_layers): - if f".{idx}." in f"{layer}.": - layer_device_map[idx] = execution_device_map[layer] - break - - for idx in range(num_hidden_layers): - if idx not in layer_device_map: - raise RuntimeError(f"layer {idx} has not been mapped to a device.") - return layer_device_map - - def _get_cache( - self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs - ) -> Cache: + def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a new `generate` call requires a larger cache or uses a different batch size. 
@@ -1926,23 +1847,10 @@ def _get_cache( ) if need_new_cache: - if hasattr(self.config, "_pre_quantization_dtype"): - cache_dtype = self.config._pre_quantization_dtype - else: - cache_dtype = self.dtype - - layer_device_map = self._get_layer_device_map_for_cache_init() cache_kwargs = { - "config": self.config.get_text_config(), - "max_batch_size": batch_size, "max_cache_len": max_cache_len, - "dtype": cache_dtype, - "device": device, - "layer_device_map": layer_device_map, + "config": self.config.get_text_config(), } - if cache_implementation in ["static", "hybrid", "offloaded_static"]: - cache_kwargs.update({"tp_size": self.tp_size}) - self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: encoder_kwargs = cache_kwargs.copy() @@ -1978,7 +1886,6 @@ def _prepare_cache_for_generation( assistant_model: "PreTrainedModel", batch_size: int, max_cache_length: int, - device: torch.device, ) -> bool: """ Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is @@ -2051,7 +1958,6 @@ def _prepare_cache_for_generation( cache_implementation=generation_config.cache_implementation, batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, max_cache_len=max_cache_length, - device=device, model_kwargs=model_kwargs, ) elif generation_config.cache_implementation == "quantized": @@ -2473,7 +2379,7 @@ def generate( ): max_cache_length += inputs_tensor.shape[1] self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length ) # 8. determine generation mode diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index f9bc88eaa138..0afe61ca78f5 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -514,13 +514,16 @@ def __init__( self.model = model self.static_cache = StaticCache( - config=config, - max_batch_size=generation_config.cache_config.get("batch_size"), max_cache_len=generation_config.cache_config.get("max_cache_len"), - device=generation_config.cache_config.get("device"), - dtype=self.model.dtype, + config=config, ) - + batch_size = generation_config.cache_config.get("batch_size") + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + device = generation_config.cache_config.get("device") + dtype = self.model.dtype + # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable) + self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device) for i in range(len(self.static_cache)): self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) @@ -667,13 +670,14 @@ def __init__( raise AssertionError("Model must have caching enabled.") # Initialize the HybridCache - self.cache = HybridCache( - config=config, - max_batch_size=generation_config.cache_config.get("batch_size"), - max_cache_len=generation_config.cache_config.get("max_cache_len"), - device=generation_config.cache_config.get("device"), - dtype=self.model.dtype, - ) + self.cache = HybridCache(max_cache_len=generation_config.cache_config.get("max_cache_len"), config=config) + head_dim = getattr(config, 
"head_dim", config.hidden_size // config.num_attention_heads) + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + max_batch_size = generation_config.cache_config.get("batch_size") + device = generation_config.cache_config.get("device") + dtype = self.model.dtype + # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable) + self.cache.early_initialization(max_batch_size, num_heads, head_dim, dtype, device) # Register all key and value cache tensors as buffers for i in range(len(self.cache)): @@ -814,13 +818,10 @@ def __init__(self, model, max_static_cache_length, batch_size): self.config = model.config # Initialize static cache for decoder and DynamicCache for encoder - self.static_cache = StaticCache( - config=self.config, - max_batch_size=batch_size, - max_cache_len=max_static_cache_length, - device="cpu", - dtype=torch.float32, - ) + self.static_cache = StaticCache(max_cache_len=max_static_cache_length, config=self.config) + head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) + self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu") self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) # Register cache buffers to make them exportable diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index eaae2133b66b..fd3e8c94bd98 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -31,7 +31,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -86,7 +86,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -100,12 +100,9 @@ class HybridMambaAttentionDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" - key_cache = None - value_cache = None is_compileable = False def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv @@ -181,13 +178,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - class BambaRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 95c5ce8e36d7..c75ca632a883 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,7 +42,6 @@ segment_sum, ) -from ...cache_utils import DynamicLayer from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -115,7 +114,6 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): """ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py index 5111e77644b3..7cac22f0d483 100644 --- a/src/transformers/models/dia/generation_dia.py +++ b/src/transformers/models/dia/generation_dia.py @@ -347,7 +347,7 @@ def _main_generate_loop( ): max_cache_length += inputs_tensor.shape[1] self._prepare_cache_for_generation( - generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length ) # 8. determine generation mode diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index b9b38b4c9c98..bdbdced722aa 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -32,7 +32,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -63,7 +63,7 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(Cache): +class FalconHybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
@@ -77,8 +77,6 @@ class FalconHybridMambaAttentionDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ - key_cache = None - value_cache = None is_compileable = False def __init__( @@ -187,13 +185,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("FalconHybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("FalconHybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - def update_conv_state( self, layer_idx: int, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index d95d83d2c9b3..3c16f4a297e0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -216,7 +216,7 @@ def forward( embed_positions = self._get_embed_positions(position_ids) repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) - sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) if self.rotary_dim is not None: @@ -302,7 +302,7 @@ def forward( embed_positions = self._get_embed_positions(position_ids) repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) - sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) if self.rotary_dim is not None: diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 91439bb2a3c9..496ac56804ab 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -27,7 +27,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer @@ -224,7 +224,7 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -238,12 +238,9 @@ class HybridMambaAttentionDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" - key_cache = None - value_cache = None is_compileable = False def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv @@ -319,13 +316,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - # Helper methods for segment sum computation diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 4fe7d6cee106..f412d589c27b 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -28,7 +28,6 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -191,7 +190,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -205,12 +204,9 @@ class HybridMambaAttentionDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" - key_cache = None - value_cache = None is_compileable = False def __init__(self, config, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba @@ -274,13 +270,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba class JambaAttention(nn.Module): diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 89f269f8a0fc..f2659b56935a 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -1220,7 +1220,6 @@ def _prepare_model_inputs( cache_methods = [ "_prepare_cache_for_generation", "_get_cache", - "_get_layer_device_map_for_cache_init", ] for method in cache_methods: setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) @@ -1229,13 +1228,13 @@ def _prepare_model_inputs( self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) ) + self.codec_model.generation_config.cache_implementation = "dynamic" self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, assistant_model=None, batch_size=batch_size, max_cache_length=self.config.codec_config.sliding_window, - device=device, ) if "past_key_values" in temporary_model_kwargs: diff --git a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py index e0e424ac605e..c1612ba435da 100644 --- a/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py @@ -344,7 +344,6 @@ def _prepare_model_inputs( cache_methods = [ "_prepare_cache_for_generation", "_get_cache", - "_get_layer_device_map_for_cache_init", ] for method in cache_methods: setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) @@ -353,13 +352,13 @@ def _prepare_model_inputs( self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model) ) + self.codec_model.generation_config.cache_implementation = "dynamic" self.codec_model._prepare_cache_for_generation( generation_config=self.codec_model.generation_config, model_kwargs=temporary_model_kwargs, assistant_model=None, batch_size=batch_size, max_cache_length=self.config.codec_config.sliding_window, - device=device, ) if "past_key_values" in temporary_model_kwargs: diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 
7fc244cb58ae..092a8b3caa82 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from torch import nn -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask @@ -122,7 +122,7 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache(DynamicCache): +class Lfm2HybridConvCache: """ Attention and conv cache for Lfm2. @@ -254,16 +254,12 @@ def crop(self, max_length: int): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + def __len__(self) -> int: + return len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") - def reset(self): for layer_idx in range(len(self.conv_cache)): # In-place ops prevent breaking the static address diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 046d79dbdd40..5d3791cbe3b1 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -17,7 +17,6 @@ import torch.nn.functional as F from torch import nn -from ...cache_utils import DynamicCache from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast @@ -81,7 +80,7 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache(DynamicCache): +class Lfm2HybridConvCache: """ Attention and conv cache for Lfm2. 
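The hunks above leave `Lfm2HybridConvCache` as a plain class whose per-layer contents are reached only through `__len__` and `__getitem__`. A minimal, self-contained sketch of that sequence-style protocol; the `_ToyCache` stand-in and its tensor sizes are assumed for illustration and are not taken from this patch:

import torch

# Duck-typed per-layer iteration that relies only on __len__ and __getitem__,
# which is the interface kept after dropping the DynamicCache parent and the
# legacy-format converters.
def cache_seq_lengths(cache) -> list:
    return [cache[i][0].shape[-2] for i in range(len(cache))]

class _ToyCache:
    """Stand-in with the same key_cache/value_cache layout the real class populates."""
    def __init__(self):
        self.key_cache = [torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)]
        self.value_cache = [torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)]
    def __len__(self) -> int:
        return len(self.key_cache)
    def __getitem__(self, layer_idx):
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

print(cache_seq_lengths(_ToyCache()))  # [5, 5]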
@@ -213,16 +212,12 @@ def crop(self, max_length: int): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + def __len__(self) -> int: + return len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("Lfm2HybridConvCache does not have a legacy cache equivalent.") - def reset(self): for layer_idx in range(len(self.conv_cache)): # In-place ops prevent breaking the static address diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a4beb1ddf980..74f61e79029e 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1266,7 +1266,6 @@ def generate( assistant_model=None, batch_size=batch_size, max_cache_length=max_cache_length, - device=input_ids_length.device, ) # 7. Prepare `input_ids` which will be used for auto-regressive generation diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index f2c3d6af4b82..b0afcf6da7ef 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -2184,7 +2184,6 @@ def generate( assistant_model=None, batch_size=batch_size, max_cache_length=max_cache_length, - device=inputs_tensor.device, ) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 367b4dc4566c..5f1c592d3230 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1569,7 +1569,6 @@ def extend_enc_output(tensor, num_beams=None): assistant_model=None, batch_size=input_ids.shape[0], max_cache_length=generation_config.max_length - 1, - device=input_ids.device, ) if generation_config.num_beams == 1: diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index e04af25febb0..dfe5b9bdd589 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -94,7 +94,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache(Cache): +class ZambaHybridDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
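The docstring kept above distinguishes the two kinds of state a hybrid layer carries. A small sketch of the shape behaviour it refers to, using made-up sizes (batch, heads, d_inner, etc. are illustrative, not the model's real dimensions): the attention cache grows along the sequence axis while the mamba conv/ssm states keep a constant shape.

import torch

# Illustrative sizes only.
batch_size, num_heads, head_dim = 1, 4, 16
d_inner, d_conv, d_state = 64, 4, 8

key_cache = torch.zeros(batch_size, num_heads, 0, head_dim)  # grows with seq_len
conv_state = torch.zeros(batch_size, d_inner, d_conv)        # constant shape
ssm_state = torch.zeros(batch_size, d_inner, d_state)        # constant shape

for step in range(3):
    new_key = torch.randn(batch_size, num_heads, 1, head_dim)
    key_cache = torch.cat([key_cache, new_key], dim=-2)       # seq_len dimension grows by one
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)   # rolling buffer, same shape
    print(step, tuple(key_cache.shape), tuple(conv_state.shape), tuple(ssm_state.shape))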
@@ -108,8 +108,6 @@ class ZambaHybridDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ - key_cache = None - value_cache = None is_compileable = False def __init__(self, config, batch_size, dtype=torch.float16, device=None): @@ -191,13 +189,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") - def eager_attention_forward( module: nn.Module, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2f1e1e0bc6b2..dcc88def1002 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -98,7 +98,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache(Cache): +class Zamba2HybridDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -112,8 +112,6 @@ class Zamba2HybridDynamicCache(Cache): and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
""" - key_cache = None - value_cache = None is_compileable = False def __init__( @@ -192,13 +190,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") - def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor ) -> torch.Tensor: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index c28ae9a5b144..97798ff9ed14 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -206,6 +206,7 @@ is_pytesseract_available, is_pytest_available, is_pytorch_quantization_available, + is_quanto_greater, is_quark_available, is_qutlass_available, is_rich_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index da740e68de9c..0ce888db6f99 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1278,6 +1278,24 @@ def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version) +@lru_cache +def is_quanto_greater(library_version: str, accept_dev: bool = False): + """ + Accepts a library version and returns True if the current version of the library is greater than or equal to the + given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches + 2.7.0). 
+ """ + if not _is_package_available("optimum-quanto"): + return False + + if accept_dev: + return version.parse(version.parse(importlib.metadata.version("optimum-quanto")).base_version) > version.parse( + library_version + ) + else: + return version.parse(importlib.metadata.version("optimum-quanto")) > version.parse(library_version) + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0f7966a9c9eb..c6376617341a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -4083,16 +4083,7 @@ def test_init_static_cache_multi_accelerator(self): # ) # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) - # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1 - layer_device_map = {0: 0, 1: 1} - past_key_values = StaticCache( - config=model.config, - max_batch_size=1, - max_cache_len=30, - device=torch_device, - dtype=model.dtype, - layer_device_map=layer_device_map, - ) + past_key_values = StaticCache(config=model.config, max_cache_len=30) results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) # check device of each layer @@ -4287,13 +4278,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self): max_cache_len = 10 batch_size = 2 query_length = input_ids.shape[-1] - init_input_ids.shape[-1] - static_cache = StaticCache( - config=config, - max_batch_size=batch_size, - max_cache_len=max_cache_len, - device=torch_device, - dtype=torch.float32, - ) + static_cache = StaticCache(config=config, max_cache_len=max_cache_len) static_cache = model(init_input_ids, past_key_values=static_cache).past_key_values model_inputs = model.prepare_inputs_for_generation( input_ids, past_key_values=static_cache, cache_position=cache_position, attention_mask=attention_mask diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 245163d672c3..653a8254616b 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -297,6 +297,26 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, HybridMambaAttentionDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = self.model_tester_class(self) self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 22dfe8be07b4..c14cc8b1d4b7 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ 
b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -666,6 +666,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, d_model=self.d_model, decoder_layers=self.decoder_layers, + num_hidden_layers=self.decoder_layers, decoder_ffn_dim=self.decoder_ffn_dim, encoder_attention_heads=self.encoder_attention_heads, decoder_attention_heads=self.decoder_attention_heads, diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 8d75649d8cc1..5a05fd574684 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -419,6 +419,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, d_model=self.d_model, decoder_layers=self.decoder_layers, + num_hidden_layers=self.decoder_layers, decoder_ffn_dim=self.decoder_ffn_dim, encoder_attention_heads=self.encoder_attention_heads, decoder_attention_heads=self.decoder_attention_heads, diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py index 25ca02d5ba43..f376fab87e14 100644 --- a/tests/models/diffllama/test_modeling_diffllama.py +++ b/tests/models/diffllama/test_modeling_diffllama.py @@ -764,13 +764,7 @@ def test_stacked_causal_mask_static_cache(self): # upgrade the model with StaticCache max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) + past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len) padded_attention_mask = torch.nn.functional.pad( input=mask_shared_prefix, @@ -812,13 +806,7 @@ def test_partial_stacked_causal_mask_static_cache(self): # upgrade the model with StaticCache max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) + past_key_values = StaticCache(config=self.model.config, max_cache_len=max_cache_len) # forward run for the first part of input part_a = 3 # split point diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 37afc2cceba1..efcf00798de0 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -38,7 +38,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model + from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model from transformers.models.falcon_h1.modeling_falcon_h1 import ( FalconHybridMambaAttentionDynamicCache, ) @@ -273,7 +273,7 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ) def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): - self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + self.assertIsInstance(decoder_past_key_values, FalconHybridMambaAttentionDynamicCache) # (batch, head, seq_length, head_features) expected_shape = ( @@ -283,31 +283,14 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value config.hidden_size // config.num_attention_heads, ) - if 
isinstance(decoder_past_key_values, Cache): - self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), - ) - self.assertListEqual( - [value_cache.shape for value_cache in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), - ) - - # Legacy cache format checks. This branch should be removed when all models use `Cache` by default - else: - self.assertListEqual( - [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], - [True] * len(decoder_past_key_values), - ) - # check shape key, value - self.assertListEqual( - [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], - [expected_shape] * len(decoder_past_key_values), - ) - self.assertListEqual( - [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], - [expected_shape] * len(decoder_past_key_values), - ) + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 98ccf21e59b3..c1627fc59f2f 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -342,6 +342,26 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, HybridMambaAttentionDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = JambaModelTester(self) self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0867a5a27068..d58837cc0fbd 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -504,13 +504,7 @@ def test_stacked_causal_mask_static_cache(self): # upgrade the model with StaticCache max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) + past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config) padded_attention_mask = torch.nn.functional.pad( input=mask_shared_prefix, @@ -552,13 +546,7 @@ def 
test_partial_stacked_causal_mask_static_cache(self): # upgrade the model with StaticCache max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] - past_key_values = StaticCache( - config=self.model.config, - max_batch_size=1, - max_cache_len=max_cache_len, - device=torch_device, - dtype=self.model.dtype, - ) + past_key_values = StaticCache(max_cache_len=max_cache_len, config=self.model.config) # forward run for the first part of input part_a = 3 # split point diff --git a/tests/models/llama4/test_modeling_llama4.py b/tests/models/llama4/test_modeling_llama4.py index 5ecc4732a2ab..a0113dcb8eb7 100644 --- a/tests/models/llama4/test_modeling_llama4.py +++ b/tests/models/llama4/test_modeling_llama4.py @@ -83,7 +83,7 @@ def setUp(self): def tearDown(self): cleanup(torch_device, gc_collect=True) - def test_model_17b_16e_fp16(self): + def test_model_17b_16e_fp32(self): EXPECTED_TEXTS = Expectations( { ("xpu", 3): ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach with a blue sky and a body of water in the background. The cow is brown with a white face'], diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 291814efde5d..99afab0843b2 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -694,6 +694,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, d_model=self.d_model, decoder_layers=self.decoder_layers, + num_hidden_layers=self.decoder_layers, decoder_ffn_dim=self.decoder_ffn_dim, encoder_attention_heads=self.encoder_attention_heads, decoder_attention_heads=self.decoder_attention_heads, diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 4ef22c3c30e0..0a69d0ad062f 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -591,6 +591,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, d_model=self.d_model, decoder_layers=self.decoder_layers, + num_hidden_layers=self.decoder_layers, decoder_ffn_dim=self.decoder_ffn_dim, encoder_attention_heads=self.encoder_attention_heads, decoder_attention_heads=self.decoder_attention_heads, diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 17964dd68c27..f56f03565b0e 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -320,8 +320,8 @@ def test_compile_static_cache(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) # Static Cache + compile - forward_function = model.forward - model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) + forward_function = model.__call__ + model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) generated_ids = model.generate( **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) @@ -330,7 +330,7 @@ def test_compile_static_cache(self): # Sliding Window Cache + compile torch._dynamo.reset() - model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) + model.__call__ = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True) generated_ids = model.generate( **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window" ) diff --git 
a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index af119c41d335..e1dd7676b348 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -452,6 +452,7 @@ def prepare_config_and_inputs(self): vocab_size=self.vocab_size, d_model=self.d_model, decoder_layers=self.decoder_layers, + num_hidden_layers=self.decoder_layers, decoder_ffn_dim=self.decoder_ffn_dim, encoder_attention_heads=self.encoder_attention_heads, decoder_attention_heads=self.decoder_attention_heads, diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 6887c0c6cd64..f80015eeeb56 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -46,13 +46,7 @@ class Phi3MiniWithStaticCache(torch.nn.Module): def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int): super().__init__() self.model = model - self.cache = StaticCache( - config=model.config, - max_batch_size=batch_size, - max_cache_len=max_seq_len, - device=self.model.device, - dtype=self.model.dtype, - ) + self.cache = StaticCache(config=model.config, max_cache_len=max_seq_len) def forward( self, diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py index f8cf7d455d20..d53ca6173395 100644 --- a/tests/models/phimoe/test_modeling_phimoe.py +++ b/tests/models/phimoe/test_modeling_phimoe.py @@ -42,13 +42,7 @@ class PhimoeMiniWithStaticCache(torch.nn.Module): def __init__(self, model: PhimoeForCausalLM, batch_size: int, max_seq_len: int): super().__init__() self.model = model - self.cache = StaticCache( - config=model.config, - max_batch_size=batch_size, - max_cache_len=max_seq_len, - device=self.model.device, - dtype=self.model.dtype, - ) + self.cache = StaticCache(config=model.config, max_cache_len=max_seq_len) def forward( self, diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 7140373081bb..431417f4c18b 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -304,6 +304,26 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, ZambaHybridDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = ZambaModelTester(self) self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=37) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 3f35a54acb66..cb742707d713 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -315,6 +315,26 @@ class Zamba2ModelTest(ModelTesterMixin, 
GenerationTesterMixin, PipelineTesterMix test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, Zamba2HybridDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = Zamba2ModelTester(self) self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37) diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index b339343627b3..2fbc4595f302 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -223,11 +223,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu # Setup static KV cache for generation past_key_values = StaticCache( - config=self.quantized_model.config, - max_batch_size=1, - max_cache_len=seq_length + self.max_new_tokens + 1, - device=torch_device, - dtype=self.quantized_model.config._pre_quantization_dtype, + config=self.quantized_model.config, max_cache_len=seq_length + self.max_new_tokens + 1 ) # Allocate token ids to be generated and copy prefix ids diff --git a/tests/quantization/spqr_integration/test_spqr.py b/tests/quantization/spqr_integration/test_spqr.py index 9f7ab7f4b9b1..443b687d54a8 100644 --- a/tests/quantization/spqr_integration/test_spqr.py +++ b/tests/quantization/spqr_integration/test_spqr.py @@ -204,11 +204,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu # Setup static KV cache for generation past_key_values = StaticCache( - config=self.quantized_model.config, - max_batch_size=1, - max_cache_len=seq_length + self.max_new_tokens + 1, - device=torch_device, - dtype=self.quantized_model.config._pre_quantization_dtype, + config=self.quantized_model.config, max_cache_len=seq_length + self.max_new_tokens + 1 ) # Allocate token ids to be generated and copy prefix ids diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 74b19395a67f..37d15452c7ed 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -49,12 +49,10 @@ DynamicCache, Gemma2Config, GenerationConfig, - HQQQuantizedCacheProcessor, HybridCache, HybridChunkedCache, LlamaConfig, QuantizedCache, - QuantoQuantizedCacheProcessor, SlidingWindowCache, StaticCache, convert_and_export_with_cache, @@ -142,7 +140,7 @@ def _random_kvs(config): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mha_static_cache = StaticCache(config=mha_config, max_cache_len=10) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -150,7 +148,7 @@ def _random_kvs(config): self.assertTrue(cached_values.shape 
== (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + gqa_static_cache = StaticCache(config=gqa_config, max_cache_len=10) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -158,7 +156,7 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mqa_static_cache = StaticCache(config=mqa_config, max_cache_len=10) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -294,20 +292,11 @@ def test_quantized_cache_generation(self, backend): ) self.assertIsInstance(gen_out.past_key_values, QuantizedCache) - processor = gen_out.past_key_values.cache_processor - if backend == "quanto": - self.assertIsInstance(processor, QuantoQuantizedCacheProcessor) - elif backend == "hqq": - self.assertIsInstance(processor, HQQQuantizedCacheProcessor) decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) self.assertListEqual(decoded, expected_generation) - self.assertTrue(len(processor._quantized_keys) > 0) - # Check that something is actually quantized - has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys) - self.assertTrue(has_been_quantized) @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): @@ -360,7 +349,7 @@ def test_dynamic_cache_hard(self): set_seed(0) gen_out = model.generate( - **inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True + **inputs, do_sample=True, top_k=5, max_new_tokens=256, return_dict_in_generate=True, output_scores=True ) decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) # sum of the scores for the generated tokens @@ -371,21 +360,21 @@ def test_dynamic_cache_hard(self): EXPECTED_GENERATION = ( "Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have " - "a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is " - "covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom " - "themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, " - "some are not. They have a lot of different personalities. They can be very independent, or they can be " - "very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or " - "they can be very silly. They have a lot of different behaviors. They can be very curious, or they can " - "be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they " - "can be very solitary. They can be very active, or they can be very inactive. They can be very " - "affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can " - "be very intelligent, or they can be very silly. They have a lot of different behaviors. 
They can be " - "very curious, or they can" - ) - EXPECTED_SCORE_SUM = 11017.4971 + "a face with a nose, eyes, and mouth. They have fur, they have claws, and they have whiskers. They are " + "usually small, but some are big. They are usually gray or black or white, but they can be many colors. " + "They have a soft body, they are usually quiet, but they can be loud. They are good at catching mice, " + "and they are good at climbing trees. They are often kept as pets, and they are often seen in homes. " + "They are independent, but they can be affectionate with their owners. They have a keen sense of smell, " + "and they can hear sounds that humans cannot hear. They have a good sense of balance, which helps them " + "to jump and climb. They are also good at hunting, and they can be trained to do tricks. They are often " + "used as pets, and they are also used in some jobs, like hunting or as service animals for people with " + "disabilities. They have a long life span, and they can live for many years. They are also known for " + "their agility and gracefulness. They are often associated with mystery and independence. They are also " + "known for their ability to land on their feet when they fall. They" + ) + EXPECTED_SCORE_SUM = 10834.7919921875 self.assertEqual(decoded[0], EXPECTED_GENERATION) - self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2) + self.assertAlmostEqual(score_sum.item(), EXPECTED_SCORE_SUM, places=2) self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check @parameterized.expand([("eager"), ("sdpa")]) @@ -476,9 +465,7 @@ def test_cache_copy(self): tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device, torch_dtype=torch.bfloat16) - prompt_cache = StaticCache( - config=model.config, max_batch_size=1, max_cache_len=1024, device=torch_device, dtype=torch.bfloat16 - ) + prompt_cache = StaticCache(config=model.config, max_cache_len=1024) INITIAL_PROMPT = "You are a helpful assistant. " inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(torch_device) @@ -498,11 +485,11 @@ def test_cache_copy(self): responses.append(response) EXPECTED_DECODED_TEXT = [ - "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an " - "enriching experience that broadens our horizons and allows us to explore the world beyond our comfort " - "zones. Whether it's a short weekend getaway", - "You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital " - "of France.\n\n\n\n\n\n\n<|endoftext|>", + "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a " + "wonderful way to explore the world, learn about different cultures, and create unforgettable " + "memories. Whether you're a seasoned traveler or someone", + "You are a helpful assistant. 
What is the capital of France?\n\n\n## Response:Paris is the capital" + " of France.\n\n\n\nAs an AI, I am not a human being.\n\n\n\nThe Great Wall of China is", ] self.assertEqual(responses, EXPECTED_DECODED_TEXT) @@ -899,12 +886,13 @@ def setUp(self): head_dim=1, hidden_size=1, sliding_window=self.window_size, + attention_chunk_size=self.window_size, layer_types=["full_attention"] * 1, # Static cache by default ) def test_static_cache_out_of_bounds(self): """Test StaticCache raises IndexError for out-of-bounds positions.""" - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(config=self.config, max_cache_len=self.max_cache_len) pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len with self.assertRaises(IndexError): @@ -926,7 +914,7 @@ def test_static_cache(self): update pos 3: [1.0, 2.0, 3.0, 4.0] """ # Scenario 1: Fill up to near capacity - static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len) + static_cache = StaticCache(config=self.config, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None) static_cache.update( @@ -968,19 +956,19 @@ def test_sliding_window_cache(self): # Scenario 1: Update within window, no slide yet config = copy.deepcopy(self.config) config.layer_types = ["sliding_attention"] * config.num_hidden_layers - sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + sliding_cache = SlidingWindowCache(config=config, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] sliding_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(2)}, ) sliding_cache.update( key_states=torch.tensor(3.0)[None, None, None, None], value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -989,19 +977,19 @@ def test_sliding_window_cache(self): ) # Scenario 2: Update causing slide - sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] sliding_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(4)}, ) sliding_cache.update( key_states=torch.tensor(5.0)[None, None, None, None], value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.tensor([4])}, ) self.assertEqual( sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1010,13 +998,13 @@ def test_sliding_window_cache(self): ) # Scenario 3: Long prompt handling - sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, 
max_cache_len=self.max_cache_len) + sliding_cache = SlidingWindowCache(config=config, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] sliding_cache.update( key_states=long_prefill, value_states=long_prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(6)}, ) self.assertEqual( sliding_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1038,13 +1026,13 @@ def test_hybrid_cache_static_mode(self): config.layer_types = ["full_attention"] * config.num_hidden_layers # Scenario 1 - hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + hybrid_cache_static_mode = HybridCache(config=config, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] hybrid_cache_static_mode.update( key_states=prefill, value_states=prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(4)}, + cache_kwargs={"cache_position": torch.arange(2)}, ) hybrid_cache_static_mode.update( key_states=torch.tensor(3.0)[None, None, None, None], @@ -1092,19 +1080,19 @@ def test_hybrid_cache_sliding_mode(self): config = copy.deepcopy(self.config) config.layer_types = ["sliding_attention"] * config.num_hidden_layers # Scenario 1: Update within window, no slide yet - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) - prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None] + hybrid_cache = HybridCache(config=config, max_cache_len=self.max_cache_len) + prefill = torch.tensor([1.0, 2.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(2)}, ) hybrid_cache.update( key_states=torch.tensor(3.0)[None, None, None, None], value_states=torch.tensor(3.0)[None, None, None, None], layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.tensor([2])}, ) self.assertEqual( hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1113,19 +1101,19 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 2: Update causing first slide - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_cache_len=self.max_cache_len) prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None] hybrid_cache.update( key_states=prefill, value_states=prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(4)}, ) hybrid_cache.update( key_states=torch.tensor(5.0)[None, None, None, None], value_states=torch.tensor(5.0)[None, None, None, None], layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.tensor([4])}, ) self.assertEqual( hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1138,7 +1126,7 @@ def test_hybrid_cache_sliding_mode(self): key_states=torch.tensor(6.0)[None, None, None, None], value_states=torch.tensor(6.0)[None, None, None, None], layer_idx=0, - cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size}, + 
cache_kwargs={"cache_position": torch.tensor([5])}, ) self.assertEqual( hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1147,13 +1135,13 @@ def test_hybrid_cache_sliding_mode(self): ) # Scenario 4: Long prompt handling - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_cache_len=self.max_cache_len) long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None] hybrid_cache.update( key_states=long_prefill, value_states=long_prefill, layer_idx=0, - cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size}, + cache_kwargs={"cache_position": torch.arange(6)}, ) self.assertEqual( hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(), @@ -1222,7 +1210,7 @@ def test_hybrid_cache(self): config.num_hidden_layers = 2 config.layer_types = ["full_attention", "sliding_attention"] config.sliding_window = 2 - hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len) + hybrid_cache = HybridCache(config=config, max_cache_len=self.max_cache_len) # Prefill both layers up to cache capacity prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] @@ -1324,9 +1312,9 @@ def test_hybrid_chunked_cache(self): config = copy.deepcopy(self.config) config.num_hidden_layers = 2 config.layer_types = ["full_attention", "chunked_attention"] - config.sliding_window = 2 + config.attention_chunk_size = 2 max_cache_len = 4 - chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len) + chunked_cache = HybridChunkedCache(config=config, max_cache_len=max_cache_len) # 1) PREFILL (3 tokens > sliding_window) prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None] @@ -1405,7 +1393,7 @@ def test_hybrid_chunked_cache_extra_cases(self): config.num_hidden_layers = 1 config.layer_types = ["chunked_attention"] config.sliding_window = 3 - cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3) + cache = HybridChunkedCache(config=config, max_cache_len=3) # Step 0 : multi-token prefill first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2