
Commit e80c68a

remove cache configs, make CacheLayer a mixin (joaos review)
1 parent aec9ccd commit e80c68a

11 files changed: +542 additions, -475 deletions

docs/source/en/kv_cache.md

Lines changed: 3 additions & 4 deletions
@@ -134,15 +134,15 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t
 > [!WARNING]
 > Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency.

-Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
+Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`]. The quantization backend, as well as any additional quantization-related parameters, should be passed as a dict in `cache_config`. You should use the default values for these additional parameters unless you're running out of memory. In that case, consider decreasing the residual length.

 <hfoptions id="quantized-cache">
 <hfoption id="HQQQuantizedCache">

 For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value` parameters to `1`.

 ```py
-from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache

 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")

@@ -159,7 +159,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
 For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-value` parameters to `0`.

 ```py
-from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache

 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")

@@ -273,7 +273,6 @@ from transformers.cache_utils import (
     StaticCache,
     SlidingWindowCache,
     QuantoQuantizedCache,
-    QuantizedCacheConfig,
 )

 model_id = "meta-llama/Llama-2-7b-chat-hf"

docs/source/ko/internal/generation_utils.md

Lines changed: 0 additions & 6 deletions
@@ -345,12 +345,6 @@ generation_output[:2]
 [[autodoc]] Cache
     - update

-[[autodoc]] CacheConfig
-    - update
-
-[[autodoc]] QuantizedCacheConfig
-    - validate
-
 [[autodoc]] DynamicCache
     - update
     - get_seq_length

src/transformers/cache_utils.py

Lines changed: 481 additions & 425 deletions
Large diffs are not rendered by default.
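
Since the `cache_utils.py` diff is not rendered, the following is only a generic, hypothetical illustration of the "CacheLayer as a mixin" idea named in the commit title: shared per-layer behaviour lives in a small mixin that concrete layer types combine with their own storage logic. The class and method names below are made up and do not reflect the actual transformers API.

```py
import torch

# Hypothetical names; not the real transformers classes.
class CacheLayerMixin:
    """Shared per-layer bookkeeping that concrete cache layers mix in."""

    def get_seq_length(self) -> int:
        keys = getattr(self, "keys", None)
        return 0 if keys is None else keys.shape[-2]


class DynamicLayer(CacheLayerMixin):
    """A concrete layer: grows its key/value tensors as new tokens arrive."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor):
        if self.keys is None:
            self.keys, self.values = key_states, value_states
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values


layer = DynamicLayer()
layer.update(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4))
print(layer.get_seq_length())  # 3
```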

src/transformers/generation/configuration_utils.py

Lines changed: 2 additions & 26 deletions
@@ -44,7 +44,6 @@

 logger = logging.get_logger(__name__)
 METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
-CACHE_CONFIG_MAPPING = {}
 NEED_SETUP_CACHE_CLASSES_MAPPING = {}
 QUANT_BACKEND_CLASSES_MAPPING = {}
 ALL_CACHE_IMPLEMENTATIONS = []

@@ -56,18 +55,12 @@
     HybridChunkedCache,
     OffloadedHybridCache,
     OffloadedStaticCache,
-    QuantizedCacheConfig,
     QuantoQuantizedCache,
     SlidingWindowCache,
     StaticCache,
-    StaticCacheConfig,
 )
 from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor

-CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig
-CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig
-CACHE_CONFIG_MAPPING["sliding_window"] = StaticCacheConfig
-CACHE_CONFIG_MAPPING["hybrid"] = StaticCacheConfig
 NEED_SETUP_CACHE_CLASSES_MAPPING = {
     "static": StaticCache,
     "offloaded_static": OffloadedStaticCache,

@@ -188,10 +181,8 @@ class GenerationConfig(PushToHubMixin):

             If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
             our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
-        cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
-            Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
-            it will be converted to its respective `CacheConfig` internally.
-            Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
+        cache_config (`dict`, *optional*, default to `None`):
+            Arguments used in the key-value cache class can be passed in `cache_config`.
         return_legacy_cache (`bool`, *optional*, default to `True`):
             Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.

@@ -406,10 +397,6 @@ def __init__(self, **kwargs):
         self.use_cache = kwargs.pop("use_cache", True)
         self.cache_implementation = kwargs.pop("cache_implementation", None)
         self.cache_config = kwargs.pop("cache_config", None)
-        if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING:
-            cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation]
-            if isinstance(self.cache_config, dict):
-                self.cache_config = cache_config_class.from_dict(self.cache_config)
         self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
         self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)

@@ -611,17 +598,6 @@ def validate(self, strict=False):
                 f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
                 f"{ALL_CACHE_IMPLEMENTATIONS}"
             )
-        if self.cache_config is not None:
-            cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation)
-            if cache_class is None:
-                raise ValueError(
-                    "You provided a `cache_config` but the cache implementation you are using "
-                    f"({self.cache_implementation}) does not require any config. Make sure to use the "
-                    "correct cache implementation matching your cache config."
-                )
-            if not isinstance(self.cache_config, cache_class):
-                self.cache_config = cache_class.from_dict(self.cache_config)
-            self.cache_config.validate()
         # 1.3. Performance attributes
         if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
             raise ValueError(
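
A short sketch (not taken from the diff) of what the simplified handling means in practice: `cache_config` now stays a plain dict on the `GenerationConfig` instead of being converted to a `CacheConfig` subclass and validated at init time.

```py
from transformers import GenerationConfig

# The dict is stored as-is; there is no conversion to a CacheConfig subclass anymore.
generation_config = GenerationConfig(
    cache_implementation="quantized",
    cache_config={"backend": "quanto"},  # keys are forwarded to the cache class at generation time
)
print(type(generation_config.cache_config))  # <class 'dict'>
```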

src/transformers/generation/utils.py

Lines changed: 5 additions & 6 deletions
@@ -35,7 +35,6 @@
     HybridChunkedCache,
     OffloadedCache,
     OffloadedHybridCache,
-    QuantizedCacheConfig,
 )
 from ..configuration_utils import PretrainedConfig
 from ..dynamic_module_utils import (

@@ -2077,22 +2076,22 @@ def _prepare_cache_for_generation(
                cache_config = (
                    generation_config.cache_config
                    if generation_config.cache_config is not None
-                   else QuantizedCacheConfig()
+                   else {"backend": "quanto"}
                )
-               cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
+               cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config["backend"]]

-               if cache_config.backend == "quanto" and not is_optimum_quanto_available():
+               if cache_config["backend"] == "quanto" and not is_optimum_quanto_available():
                    raise ImportError(
                        "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
                        "Please install it via with `pip install optimum-quanto`"
                    )
-               elif cache_config.backend == "HQQ" and not is_hqq_available():
+               elif cache_config["backend"] == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                        "Please install it via with `pip install hqq`"
                    )

-               model_kwargs[cache_name] = cache_class(cache_config)
+               model_kwargs[cache_name] = cache_class(**cache_config)
            elif generation_config.cache_implementation == "offloaded":
                model_kwargs[cache_name] = OffloadedCache()
            elif generation_config.cache_implementation == "dynamic":
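
To illustrate the new control flow in `_prepare_cache_for_generation`, here is a self-contained sketch with stand-in classes (the real mapping and cache classes differ): the `backend` key picks the cache class, and the whole dict is unpacked as keyword arguments, replacing the old `QuantizedCacheConfig` object.

```py
# Stand-in class and mapping; only the dispatch pattern mirrors the diff above.
class _QuantoCache:
    def __init__(self, backend="quanto", nbits=4, **kwargs):
        self.backend, self.nbits = backend, nbits


QUANT_BACKEND_CLASSES = {"quanto": _QuantoCache}


def build_quantized_cache(cache_config=None):
    # Default to the quanto backend when no cache_config was provided.
    cache_config = cache_config if cache_config is not None else {"backend": "quanto"}
    cache_class = QUANT_BACKEND_CLASSES[cache_config["backend"]]
    # The dict is unpacked into keyword arguments instead of being wrapped in a config object.
    return cache_class(**cache_config)


cache = build_quantized_cache({"backend": "quanto", "nbits": 2})
print(type(cache).__name__, cache.nbits)  # _QuantoCache 2
```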

src/transformers/integrations/executorch.py

Lines changed: 3 additions & 3 deletions
@@ -276,9 +276,9 @@ def __init__(self, model: PreTrainedModel):
         self.model = model
         self.static_cache = StaticCache(
             model_config=self.model.config,
-            max_batch_size=self.model.generation_config.cache_config.batch_size,
-            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-            device=self.model.generation_config.cache_config.device,
+            max_batch_size=self.model.generation_config.cache_config.get("batch_size"),
+            max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"),
+            device=self.model.generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )
         for i in range(len(self.static_cache)):
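
Since `cache_config` is now a plain dict, the ExecuTorch wrapper reads its entries with `.get()`, which returns `None` for keys that were not provided instead of raising an `AttributeError`. A tiny illustration with hypothetical values:

```py
from transformers import GenerationConfig

generation_config = GenerationConfig(
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 128},  # "device" intentionally omitted
)

# Attribute access (cache_config.device) would fail on a dict; .get() degrades to None instead.
print(generation_config.cache_config.get("device"))  # None
```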

src/transformers/masking_utils.py

Lines changed: 2 additions & 2 deletions
@@ -693,7 +693,7 @@ def create_causal_mask(
     """
     # If we have an HybridCache structure, here we want to create the mask for the full layers
     is_sliding = []
-    if past_key_values is not None:
+    if past_key_values is not None and past_key_values.layers is not None:
         is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
     layer_idx = is_sliding.index(True) if True in is_sliding else 0

@@ -775,7 +775,7 @@ def create_sliding_window_causal_mask(
     """
     # If we have an HybridCache structure, here we want to create the mask for the sliding layers
     is_sliding = []
-    if past_key_values is not None:
+    if past_key_values is not None and past_key_values.layers is not None:
         is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
     layer_idx = is_sliding.index(True) if True in is_sliding else 0

src/transformers/models/zamba/modeling_zamba.py

Lines changed: 3 additions & 0 deletions
@@ -146,6 +146,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
     def __len__(self):
         return len(self.key_cache)

+    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
     # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update
     def update(
         self,

src/transformers/models/zamba2/modeling_zamba2.py

Lines changed: 3 additions & 0 deletions
@@ -150,6 +150,9 @@ def __init__(
     def __len__(self):
         return len(self.key_cache)

+    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
     def update(
         self,
         key_states: torch.Tensor,
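
The `__getitem__` added to both hybrid Mamba/attention caches lets callers unpack a layer's key/value tensors with tuple-style indexing, matching how legacy tuple caches are consumed. A toy stand-in class (not the real cache, which requires a model config to construct) shows the pattern:

```py
import torch


class _ToyHybridCache:
    """Stand-in with the same key_cache/value_cache lists as the real classes."""

    def __init__(self):
        self.key_cache = [torch.zeros(1, 2, 4, 8)]
        self.value_cache = [torch.zeros(1, 2, 4, 8)]

    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


key_states, value_states = _ToyHybridCache()[0]  # tuple-style access per layer
print(key_states.shape, value_states.shape)
```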

tests/models/falcon_h1/test_modeling_falcon_h1.py

Lines changed: 38 additions & 1 deletion
@@ -36,7 +36,7 @@
 if is_torch_available():
     import torch

-    from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
+    from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model
     from transformers.models.falcon_h1.modeling_falcon_h1 import (
         FalconHybridMambaAttentionDynamicCache,
     )

@@ -270,6 +270,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         {"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
     )

+    def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
+        self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
+
+        # (batch, head, seq_length, head_features)
+        expected_shape = (
+            batch_size,
+            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+            cache_length,
+            config.hidden_size // config.num_attention_heads,
+        )
+
+        if isinstance(decoder_past_key_values, Cache):
+            self.assertListEqual(
+                [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
+                [expected_shape] * len(decoder_past_key_values.key_cache),
+            )
+            self.assertListEqual(
+                [value_cache.shape for value_cache in decoder_past_key_values.value_cache],
+                [expected_shape] * len(decoder_past_key_values.value_cache),
+            )
+
+        # Legacy cache format checks. This branch should be removed when all models use `Cache` by default
+        else:
+            self.assertListEqual(
+                [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
+                [True] * len(decoder_past_key_values),
+            )
+            # check shape key, value
+            self.assertListEqual(
+                [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
+                [expected_shape] * len(decoder_past_key_values),
+            )
+            self.assertListEqual(
+                [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
+                [expected_shape] * len(decoder_past_key_values),
+            )
+
     def setUp(self):
         self.model_tester = FalconH1ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)
