@@ -1464,13 +1464,31 @@ class EncoderDecoderCache(Cache):
     ```
     """
 
-    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
-        self.self_attention_cache = self_attention_cache
-        self.cross_attention_cache = cross_attention_cache
+    def __init__(self, *caches) -> None:
+        # For DP and DDP support: if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate both caches from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we expect two arguments: a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
 
         self.is_updated = {}
-        for layer_idx in range(len(cross_attention_cache)):
-            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
 
     def __repr__(self) -> str:
         return (
@@ -1527,21 +1545,18 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
 
     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(DynamicCache(), DynamicCache())
         if past_key_values is None:
             logger.warning_once("past_key_values should not be None in from_legacy_cache()")
-        cache = cls(
-            self_attention_cache=DynamicCache(),
-            cross_attention_cache=DynamicCache(),
-        )
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx][:2]
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
                 cache.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(past_key_values[layer_idx]) > 2:
-                    key_states, value_states = past_key_values[layer_idx][2:]
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
                     cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                     cache.is_updated[layer_idx] = True
         return cache
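For reference, a minimal usage sketch (not part of the diff) of the two construction paths the new `__init__` accepts. It assumes a `transformers` build that includes this change; the tensor shapes and values are illustrative only:

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# Path 1: the existing two-argument form, one Cache per attention type.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Path 2: the new single-argument form. DataParallel/DistributedDataParallel
# scatter and re-gather module outputs, so the cache can come back as a plain
# iterable of per-layer tensor tuples rather than as a Cache instance.
k = torch.zeros(1, 8, 4, 64)  # [batch, num_heads, seq_len, head_dim]
v = torch.zeros(1, 8, 4, 64)
legacy_layers = [(k, v, k, v)]  # (self-attn k, v, cross-attn k, v) for layer 0
cache = EncoderDecoderCache(legacy_layers)

print(cache.self_attention_cache.get_seq_length(0))   # 4
print(cache.cross_attention_cache.get_seq_length(0))  # 4
print(cache.is_updated[0])  # True: cross-attention states were populated
```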