diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 56a9e7a4b5a9..0c1524c6164a 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1464,13 +1464,31 @@ class EncoderDecoderCache(Cache):
     ```
     """
 
-    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
-        self.self_attention_cache = self_attention_cache
-        self.cross_attention_cache = cross_attention_cache
+    def __init__(self, *caches) -> None:
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate cache from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
 
         self.is_updated = {}
-        for layer_idx in range(len(cross_attention_cache)):
-            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
 
     def __repr__(self) -> str:
         return (
@@ -1527,21 +1545,18 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
 
     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(DynamicCache(), DynamicCache())
         if past_key_values is None:
             logger.warning_once("past_key_values should not be None in from_legacy_cache()")
-        cache = cls(
-            self_attention_cache=DynamicCache(),
-            cross_attention_cache=DynamicCache(),
-        )
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx][:2]
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
                 cache.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(past_key_values[layer_idx]) > 2:
-                    key_states, value_states = past_key_values[layer_idx][2:]
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
                     cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                     cache.is_updated[layer_idx] = True
         return cache
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 4aa44a9afb4f..8c3fc0fa4e53 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -446,9 +446,7 @@ def forward(
         elif isinstance(past_key_values, DynamicCache):
             past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
         elif past_key_values is None:
-            past_key_values = EncoderDecoderCache(
-                self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
-            )
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 535008a1a02f..fc704fd23920 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -25,6 +25,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_accelerate,
     require_sentencepiece,
@@ -1200,7 +1201,12 @@ def test_small_integration_test(self):
         loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
         mtf_score = -(labels.shape[-1] * loss.item())
 
-        EXPECTED_SCORE = -19.0845
+        EXPECTED_SCORE = Expectations(
+            {
+                (None, None): -19.0845,
+                ("rocm", (9, 4)): -19.0846,
+            }
+        ).get_expectation()
         self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
 
     @slow
diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py
index 95279fae5bc0..abfc7e74c646 100644
--- a/tests/models/t5gemma/test_modeling_t5gemma.py
+++ b/tests/models/t5gemma/test_modeling_t5gemma.py
@@ -1386,10 +1386,6 @@ def test_flex_attention_with_grads(self):
         # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
         _ = model(**dummy_inputs)
 
-    @unittest.skip("EncoderDecoderCache can't be gathered because it is not iterable.")
-    def test_multi_gpu_data_parallel_forward(self):
-        pass
-
 
 class T5GemmaEncoderOnlyModelTester:
     config_class = T5GemmaConfig
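
For reference, a minimal usage sketch of the reworked constructor (illustrative code, not part of the patch; the variable names and tensor shapes are made up): with this change, EncoderDecoderCache can be built either from two Cache objects, as before, or from a single iterable of per-layer tensor tuples, which is the form produced when DataParallel gathers the cache across replicas.

    import torch
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    # Two-argument form: explicit self-attention and cross-attention caches (previous behaviour, still supported).
    cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

    # One-argument form: an iterable of per-layer tuples, e.g. what DP/DDP gathering yields.
    # Each tuple holds (self_attn_key, self_attn_value) and optionally (cross_attn_key, cross_attn_value).
    k = v = torch.zeros(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim), illustrative shapes only
    per_layer = [(k, v, k, v), (k, v, k, v)]
    gathered = EncoderDecoderCache(per_layer)

    # Any other argument count raises ValueError; a two-argument call with a non-Cache argument raises TypeError.
    assert len(gathered.self_attention_cache) == len(gathered.cross_attention_cache) == 2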