45 changes: 30 additions & 15 deletions src/transformers/cache_utils.py
@@ -1464,13 +1464,31 @@ class EncoderDecoderCache(Cache):
         ```
     """
 
-    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
-        self.self_attention_cache = self_attention_cache
-        self.cross_attention_cache = cross_attention_cache
+    def __init__(self, *caches) -> None:
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate cache from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
 
         self.is_updated = {}
-        for layer_idx in range(len(cross_attention_cache)):
-            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
 
     def __repr__(self) -> str:
         return (
@@ -1527,21 +1545,18 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
 
     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
-        cache = cls(
-            self_attention_cache=DynamicCache(),
-            cross_attention_cache=DynamicCache(),
-        )
+        cache = cls(DynamicCache(), DynamicCache())
         if past_key_values is None:
             logger.warning_once("past_key_values should not be None in from_legacy_cache()")
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx][:2]
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
                 cache.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(past_key_values[layer_idx]) > 2:
-                    key_states, value_states = past_key_values[layer_idx][2:]
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
                     cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                     cache.is_updated[layer_idx] = True
         return cache
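A quick illustration of the two construction paths the new `__init__` accepts: two `Cache` objects as before (now positional), or a single iterable of per-layer tensor tuples, which is the format the "dp and ddp support" comment refers to. This is only a sketch; the tensor shapes are made up for illustration.

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

# Path 1: explicit self-attention and cross-attention caches (positional).
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Path 2: a single iterable of per-layer tuples,
# (self_key, self_value[, cross_key, cross_value]).
layer = tuple(torch.zeros(1, 2, 4, 8) for _ in range(4))  # illustrative shapes only
cache = EncoderDecoderCache([layer, layer])
print(len(cache.self_attention_cache), len(cache.cross_attention_cache))  # 2 2
```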
4 changes: 1 addition & 3 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -446,9 +446,7 @@ def forward(
         elif isinstance(past_key_values, DynamicCache):
             past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
         elif past_key_values is None:
-            past_key_values = EncoderDecoderCache(
-                self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
-            )
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
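One side effect of the `*caches` signature worth flagging for other call sites: the keyword form removed here is no longer accepted, which is presumably why this call was switched to positional arguments. A small sketch of the behavior implied by the new signature:

```python
from transformers import DynamicCache, EncoderDecoderCache

# Positional arguments work under both the old and the new signature.
EncoderDecoderCache(DynamicCache(), DynamicCache())

# Keyword arguments cannot be captured by a bare `*caches` parameter,
# so this call is expected to raise a TypeError after this change.
try:
    EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
except TypeError as exc:
    print(exc)
```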
8 changes: 7 additions & 1 deletion tests/models/t5/test_modeling_t5.py
@@ -25,6 +25,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_accelerate,
     require_sentencepiece,
@@ -1200,7 +1201,12 @@ def test_small_integration_test(self):
         loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
         mtf_score = -(labels.shape[-1] * loss.item())
 
-        EXPECTED_SCORE = -19.0845
+        EXPECTED_SCORE = Expectations(
+            {
+                (None, None): -19.0845,
+                ("rocm", (9, 4)): -19.0846,
+            }
+        ).get_expectation()
         self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
 
     @slow
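For context, `Expectations` (from `transformers.testing_utils`) maps `(device, version)` keys to reference values, and `get_expectation()` returns the entry that best matches the current runtime, with `(None, None)` as the default; the exact matching rules live in `testing_utils`. A minimal sketch mirroring the change above; the `("cuda", 8)` entry is a hypothetical addition that shows the device/compute-capability key form and reuses the default value:

```python
from transformers.testing_utils import Expectations

EXPECTED_SCORE = Expectations(
    {
        (None, None): -19.0845,      # default for any unmatched device
        ("cuda", 8): -19.0845,       # hypothetical: compute capability 8.x GPUs
        ("rocm", (9, 4)): -19.0846,  # ROCm 9.4 runners, as in the diff
    }
).get_expectation()
```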
4 changes: 0 additions & 4 deletions tests/models/t5gemma/test_modeling_t5gemma.py
@@ -1386,10 +1386,6 @@ def test_flex_attention_with_grads(self):
         # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
         _ = model(**dummy_inputs)
 
-    @unittest.skip("EncoderDecoderCache can't be gathered because it is not iterable.")
-    def test_multi_gpu_data_parallel_forward(self):
-        pass
-
 
 class T5GemmaEncoderOnlyModelTester:
     config_class = T5GemmaConfig
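The removed skip ties back to the constructor change in `cache_utils.py`: when `torch.nn.DataParallel` gathers replica outputs, it rebuilds non-tensor objects by calling their class with a single iterable of gathered per-layer entries (roughly `type(out)(map(gather_map, zip(*outputs)))` in `torch.nn.parallel.scatter_gather`), a call shape the old two-`Cache` signature could not satisfy. A rough sketch of the equivalent round trip the one-argument path now supports, assuming the legacy per-layer layout `(self_key, self_value, cross_key, cross_value)`:

```python
from transformers import EncoderDecoderCache

def roundtrip(cache: EncoderDecoderCache) -> EncoderDecoderCache:
    # Flatten to per-layer tuples (the legacy layout) and rebuild from that
    # iterable alone, the same single-iterable call DataParallel's gather makes.
    return EncoderDecoderCache(cache.to_legacy_cache())
```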