
Commit 7a01fcd

raushan review, arthur review
1 parent 26c28af commit 7a01fcd


45 files changed: +513 -618 lines

docs/source/en/cache_explanation.md

Lines changed: 7 additions & 13 deletions
@@ -82,24 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
 
 ## Cache storage implementation
 
-The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
+Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.
 
+Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.
 
-In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
-- `key_cache`: A list of tensors, one for each layer.
-- `value_cache`: A list of tensors, one for each layer.
+The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
 
-When new tokens are processed:
-
-1. For each layer, the new key and value states are concatenated with the existing cache.
 ```py
-self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
+cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
 ```
 
-2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
-
-3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token.
+Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
 
 The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
 
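To make the growth behavior described above concrete, here is a minimal sketch (not part of the diff). It assumes the `layers`/`keys` attributes introduced by this commit and uses `gpt2` purely as an example checkpoint:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello there", return_tensors="pt")
cache = DynamicCache()

# A single forward pass fills every layer's key/value tensors.
with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

# Each layer stores tensors shaped [batch_size, num_heads, seq_len, head_dim];
# for a DynamicLayer, seq_len grows with the number of tokens processed so far.
print(len(cache.layers), cache.layers[0].keys.shape)
```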

@@ -143,7 +137,7 @@ The legacy format is essentially the same data structure but organized different
 - The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`.
 - The format is less flexible and doesn't support features like quantization or offloading.
 
-If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
+If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~Cache.from_legacy_cache`] and [`Cache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.
 
 ```py
 import torch
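
The round trip mentioned in the changed line above can be sketched as follows (not part of the diff; it assumes the `from_legacy_cache`/`to_legacy_cache` helpers referenced in this commit and the new `layers` attribute):

```py
import torch
from transformers import DynamicCache

batch_size, num_heads, seq_len, head_dim = 1, 4, 3, 8

# The legacy format is a tuple (one entry per layer) of (key, value) tensor pairs.
legacy_cache = tuple(
    (
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(2)  # two layers, purely for illustration
)

cache = DynamicCache.from_legacy_cache(legacy_cache)  # tuple of tuples -> Cache
round_tripped = cache.to_legacy_cache()               # Cache -> tuple of tuples
print(len(cache.layers), round_tripped[0][0].shape)
```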

src/transformers/cache_utils.py

Lines changed: 291 additions & 387 deletions
Large diffs are not rendered by default.

src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1951,7 +1951,7 @@ def _get_cache(
 
         layer_device_map = self._get_layer_device_map_for_cache_init()
         cache_kwargs = {
-            "config": self.config.get_text_config(),
+            "model_config": self.config.get_text_config(),
             "max_batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "dtype": cache_dtype,

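As a usage note (not part of the diff): code that constructs one of these caches directly needs the same keyword rename. A minimal sketch, assuming the `model_config` argument name introduced here and an arbitrary example config:

```py
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("gpt2")  # any decoder config, just for illustration

# Before this commit the first argument was passed as `config=...`.
cache = StaticCache(
    model_config=config,
    max_batch_size=1,
    max_cache_len=128,
    device="cpu",
    dtype=torch.float32,
)
```
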
src/transformers/integrations/executorch.py

Lines changed: 12 additions & 12 deletions
@@ -275,15 +275,15 @@ def __init__(self, model: PreTrainedModel):
 
         self.model = model
         self.static_cache = StaticCache(
-            config=self.model.config,
+            model_config=self.model.config,
             max_batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
             device=self.model.generation_config.cache_config.device,
             dtype=self.model.dtype,
         )
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        for i in range(len(self.static_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
 
     def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
         """

@@ -404,17 +404,17 @@ def __init__(
 
         # Initialize the HybridCache
         self.cache = HybridCache(
-            config=self.model.config,
+            model_config=self.model.config,
             max_batch_size=max_batch_size,
             max_cache_len=max_cache_len,
             device=self.model.device,
             dtype=self.model.dtype,
         )
 
         # Register all key and value cache tensors as buffers
-        for i in range(len(self.cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
+        for i in range(len(self.cache)):
+            self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False)
 
     def forward(
         self,

@@ -550,17 +550,17 @@ def __init__(self, model, max_static_cache_length, batch_size):
 
         # Initialize static cache
        self.static_cache = StaticCache(
-            config=self.config,
+            model_config=self.config,
             max_batch_size=batch_size,
             max_cache_len=max_static_cache_length,
             device="cpu",
             dtype=torch.float32,
         )
 
         # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        for i in range(len(self.static_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
 
     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
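
The buffer-registration pattern above generalizes to any module that needs to expose cache tensors for export. A minimal sketch (not part of the diff), assuming the `layers`, `keys`, and `values` attributes introduced by this commit and an arbitrary example config:

```py
import torch
from torch import nn
from transformers import AutoConfig, StaticCache


def register_cache_buffers(module: nn.Module, cache) -> None:
    # After this commit, len(cache) is the number of layers and each layer
    # exposes its tensors as `.keys` / `.values` instead of the old
    # `cache.key_cache[i]` / `cache.value_cache[i]` lists.
    for i in range(len(cache)):
        module.register_buffer(f"key_cache_{i}", cache.layers[i].keys, persistent=False)
        module.register_buffer(f"value_cache_{i}", cache.layers[i].values, persistent=False)


config = AutoConfig.from_pretrained("gpt2")  # just for illustration
cache = StaticCache(model_config=config, max_batch_size=1, max_cache_len=64, device="cpu", dtype=torch.float32)

holder = nn.Module()
register_cache_buffers(holder, cache)
print(holder.key_cache_0.shape)
```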

src/transformers/masking_utils.py

Lines changed: 6 additions & 12 deletions
@@ -692,10 +692,8 @@ def create_causal_mask(
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the full layers
-    if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
-        layer_idx = past_key_values.is_sliding.index(False)
-    else:
-        layer_idx = 0
+    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    layer_idx = is_sliding.index(False) if False in is_sliding else 0
 
     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
         config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@@ -774,10 +772,8 @@ def create_sliding_window_causal_mask(
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the sliding layers
-    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
-        layer_idx = past_key_values.is_sliding.index(True)
-    else:
-        layer_idx = 0
+    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    layer_idx = is_sliding.index(True) if True in is_sliding else 0
 
     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
         config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
@@ -861,10 +857,8 @@ def create_chunked_causal_mask(
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the sliding layers
-    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
-        layer_idx = past_key_values.is_sliding.index(True)
-    else:
-        layer_idx = 0
+    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    layer_idx = is_sliding.index(True) if True in is_sliding else 0
 
     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
         config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
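
The layer-selection logic shared by the three mask helpers above can be sketched in isolation as follows (not part of the diff; `FullLayer`/`SlidingLayer` are illustrative stand-ins, not library classes):

```py
# The causal mask helper wants the first full (non-sliding) layer, while the
# sliding-window and chunked helpers want the first sliding layer.
class FullLayer:
    is_sliding = False


class SlidingLayer:
    is_sliding = True


layers = [SlidingLayer(), SlidingLayer(), FullLayer()]

is_sliding = [getattr(layer, "is_sliding", False) for layer in layers]
full_layer_idx = is_sliding.index(False) if False in is_sliding else 0    # -> 2
sliding_layer_idx = is_sliding.index(True) if True in is_sliding else 0   # -> 0
print(full_layer_idx, sliding_layer_idx)
```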

src/transformers/models/bart/modeling_bart.py

Lines changed: 2 additions & 2 deletions
@@ -230,8 +230,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

Lines changed: 2 additions & 2 deletions
@@ -1293,8 +1293,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 2 additions & 2 deletions
@@ -229,8 +229,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

Lines changed: 2 additions & 2 deletions
@@ -213,8 +213,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.key_cache[self.layer_idx]
-            value_states = curr_past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.layers[self.layer_idx].keys
+            value_states = curr_past_key_value.layers[self.layer_idx].values
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)
