Skip to content

Commit ada6520

Browse files
committed
Init is self-contained
1 parent 86c22de commit ada6520

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

src/transformers/cache_utils.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,31 +1464,32 @@ class EncoderDecoderCache(Cache):
14641464
```
14651465
"""
14661466

1467-
# Override @property from Cache -> this will be set in __init__ on the instances
1468-
is_compileable = False
1469-
14701467
def __init__(self, *caches) -> None:
1471-
# If only one argument is passed, it should be a legacy cache, i.e., an iterable of tuples of tensors
1468+
# If only one argument is passed, it should be an iterable of tuples of tensors
14721469
# This is not only for legacy reasons, but also to be compatible with nn.DataParallel
14731470
if len(caches) == 1:
1474-
self_attention_cache, cross_attention_cache = self.create_dynamic_caches_from_legacy_cache(caches[0])
1471+
self.self_attention_cache = DynamicCache()
1472+
self.cross_attention_cache = DynamicCache()
1473+
# Populate cache from the iterable
1474+
for layer_idx, key_value_states in enumerate(caches[0]):
1475+
key_states, value_states = key_value_states[:2]
1476+
self.self_attention_cache.update(key_states, value_states, layer_idx)
1477+
if len(key_value_states) > 2:
1478+
key_states, value_states = key_value_states[2:]
1479+
self.cross_attention_cache.update(key_states, value_states, layer_idx)
14751480
# Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
14761481
elif len(caches) == 2:
14771482
assert isinstance(caches[0], Cache), f"{type(caches[0]) = } is not a Cache"
14781483
assert isinstance(caches[1], Cache), f"{type(caches[1]) = } is not a Cache"
1479-
self_attention_cache = caches[0]
1480-
cross_attention_cache = caches[1]
1484+
self.self_attention_cache = caches[0]
1485+
self.cross_attention_cache = caches[1]
14811486
# Error case
14821487
else:
14831488
raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
14841489

1485-
# Initialize caches
1486-
self.self_attention_cache = self_attention_cache
1487-
self.cross_attention_cache = cross_attention_cache
1488-
14891490
self.is_updated = {}
1490-
for layer_idx in range(len(cross_attention_cache)):
1491-
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
1491+
for layer_idx in range(len(self.cross_attention_cache)):
1492+
self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
14921493

14931494
def __repr__(self) -> str:
14941495
return (

0 commit comments

Comments
 (0)