@@ -1464,31 +1464,32 @@ class EncoderDecoderCache(Cache):
14641464 ```
14651465 """
14661466
1467- # Override @property from Cache -> this will be set in __init__ on the instances
1468- is_compileable = False
1469-
def __init__(self, *caches) -> None:
    """Build the cache from one iterable of per-layer states or from two `Cache` objects.

    Args:
        *caches: Either
            - a single iterable yielding one sequence per layer, whose first two entries are
              the self-attention key/value states and whose remaining entries (if any) are the
              cross-attention key/value states; or
            - exactly two `Cache` instances: the self-attention cache followed by the
              cross-attention cache.

    Raises:
        TypeError: If two arguments are passed and either of them is not a `Cache`.
        ValueError: If the number of positional arguments is neither 1 nor 2.
    """
    # If only one argument is passed, it should be an iterable of tuples of tensors
    # This is not only for legacy reasons, but also to be compatible with nn.DataParallel
    if len(caches) == 1:
        self.self_attention_cache = DynamicCache()
        self.cross_attention_cache = DynamicCache()
        # Populate both caches from the iterable, one layer at a time
        for layer_idx, key_value_states in enumerate(caches[0]):
            key_states, value_states = key_value_states[:2]
            self.self_attention_cache.update(key_states, value_states, layer_idx)
            # Anything past the first two entries is the cross-attention pair; the unpacking
            # below expects exactly two remaining entries.
            if len(key_value_states) > 2:
                key_states, value_states = key_value_states[2:]
                self.cross_attention_cache.update(key_states, value_states, layer_idx)
    # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
    elif len(caches) == 2:
        # Explicit raises instead of `assert`: assertions are removed under `python -O`, which
        # would let a non-Cache object through and fail later with an unrelated error.
        if not isinstance(caches[0], Cache):
            raise TypeError(f"{type(caches[0]) = } is not a Cache")
        if not isinstance(caches[1], Cache):
            raise TypeError(f"{type(caches[1]) = } is not a Cache")
        self.self_attention_cache = caches[0]
        self.cross_attention_cache = caches[1]
    # Error case
    else:
        raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")

    # Per-layer flag: True when the cross-attention cache already holds states for that layer
    self.is_updated = {
        layer_idx: bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
        for layer_idx in range(len(self.cross_attention_cache))
    }
14921493
14931494 def __repr__ (self ) -> str :
14941495 return (
0 commit comments