
Commit 41d55aa

improve

1 parent a1cdb3b commit 41d55aa

2 files changed: +46 -37 lines

src/transformers/cache_utils.py

Lines changed: 44 additions & 35 deletions
@@ -37,6 +37,9 @@ def __init__(self):
         self.keys, self.values = None, None
         self.cumulative_length = 0

+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
     @abstractmethod
     def update(
         self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
@@ -74,9 +77,9 @@ def reset(self) -> None:
         self.values.zero_()
         self.cumulative_length = 0

-    def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         """Reorders this layer's cache for beam search."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
             self.values = self.values.index_select(0, beam_idx.to(self.values.device))

@@ -141,19 +144,19 @@ def crop(self, max_length: int) -> None:
         if self.get_seq_length() <= max_length:
             return

-        if self.keys is not None and self.keys.numel():
-            self.keys = self.keys[..., :max_length, :]
-            self.values = self.values[..., :max_length, :]
+        self.keys = self.keys[..., :max_length, :]
+        self.values = self.values[..., :max_length, :]
+        self.cumulative_length = max_length

     def batch_repeat_interleave(self, repeats: int) -> None:
         """Repeat the cache `repeats` times in the batch dimension."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys.repeat_interleave(repeats, dim=0)
             self.values = self.values.repeat_interleave(repeats, dim=0)

     def batch_select_indices(self, indices: torch.Tensor) -> None:
         """Only keep the `indices` in the batch dimension of the cache."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys[indices, ...]
             self.values = self.values[indices, ...]

@@ -167,24 +170,9 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:

     @classmethod
     def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
-        """
-        Build a `DynamicLayer` instance from pre-existing key/value tensors.
-
-        Args:
-            keys (`torch.Tensor`):
-                Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
-            values (`torch.Tensor`):
-                Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
-
-        Returns:
-            `DynamicLayer`: The newly constructed layer whose internal cache directly references
-            the supplied tensors.
-        """
+        """Build a `DynamicLayer` instance from pre-existing key/value tensors."""
         layer = cls()
-        layer.dtype, layer.device = keys.dtype, keys.device
-        layer.cumulative_length = keys.shape[-2]
-        layer.keys = keys
-        layer.values = values
+        _, _ = layer.update(keys, values)
         return layer


@@ -198,7 +186,6 @@ class DynamicSlidingWindowLayer(DynamicLayer):
     def __init__(self, sliding_window: int):
         super().__init__()
         self.sliding_window = sliding_window
-        self.cumulative_length = 0

     def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
@@ -553,7 +540,6 @@ def __init__(
         self.axis_value = axis_value
         self.q_group_size = q_group_size
         self.residual_length = residual_length
-        self.cumulative_length = 0

     def update(
         self,
@@ -599,10 +585,6 @@ def update(

         return keys_to_return, values_to_return

-    def get_seq_length(self, cache_position=None) -> int:
-        """Returns the sequence length of the cached states."""
-        return self.cumulative_length
-
     @abstractmethod
     def _quantize(self, tensor, axis): ...

@@ -710,7 +692,13 @@ def _dequantize(self, qtensor):
         return tensor


-LAYER_CLASS_MAP: dict[str, type[CacheLayerMixin]] = {
+DYNAMIC_LAYER_CLASS_MAPPING: dict[str, type[CacheLayerMixin]] = {
+    "full_attention": DynamicLayer,
+    "sliding_attention": DynamicSlidingWindowLayer,
+    "chunked_attention": DynamicSlidingWindowLayer,
+}
+
+STATIC_LAYER_CLASS_MAPPING: dict[str, type[CacheLayerMixin]] = {
     "full_attention": StaticLayer,
     "sliding_attention": SlidingWindowLayer,
     "chunked_attention": ChunkedSlidingLayer,
@@ -997,7 +985,7 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T
         else:
             super().__init__(layer_class_to_replicate=DynamicLayer)

-    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
+    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility.
@@ -1008,7 +996,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
         return legacy_cache

     @classmethod
-    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache":
+    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
         """
         Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
         backward compatibility.
@@ -1069,6 +1057,27 @@ def _unflatten_dynamic_cache(
    )


+class HybridDynamicCache(Cache):
+
+    def __init__(self, config: PretrainedConfig):
+        sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
+        if hasattr(config, "layer_types"):
+            layers = []
+            init_kwargs = {}
+            for layer_type in config.layer_types:
+                if layer_type == "sliding_attention":
+                    init_kwargs["sliding_window"] = config.sliding_window
+                elif layer_type == "chunked_attention":
+                    init_kwargs["sliding_window"] = config.attention_chunk_size
+                layers.append(DYNAMIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
+        elif sliding_window is not None:
+            # In this case, fall back to a full sliding cache
+            layers = [DynamicSlidingWindowLayer(sliding_window) for _ in range(config.num_hidden_layers)]
+        else:
+            # In this case, fall back to DynamicCache
+            layers = [DynamicLayer() for _ in range(config.num_hidden_layers)]
+        super().__init__(layers=layers)
+
 class OffloadedCache(Cache):
     """
     A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
@@ -1217,7 +1226,7 @@ def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
                    init_kwargs["sliding_window"] = config.sliding_window
                elif layer_type == "chunked_attention":
                    init_kwargs["sliding_window"] = config.attention_chunk_size
-                layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs))
+                layers.append(STATIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
        else:
            # In this case, fall back to StaticCache
            layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
@@ -1249,7 +1258,7 @@ def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
                    init_kwargs["sliding_window"] = config.sliding_window
                elif layer_type == "chunked_attention":
                    init_kwargs["sliding_window"] = config.attention_chunk_size
-                layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs))
+                layers.append(STATIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
        else:
            # In this case, fall back to StaticCache
            layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
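For orientation, here is a rough usage sketch of the new `HybridDynamicCache` introduced above. It is a hedged illustration, not part of the commit: it assumes a transformers build that includes this change, that `Cache` keeps its per-layer objects in a `layers` list (as the `super().__init__(layers=layers)` call suggests), and the `MistralConfig` sizes are placeholders.

# Hedged sketch: a config that defines `layer_types` gets one dynamic layer per
# entry via DYNAMIC_LAYER_CLASS_MAPPING; a config with only `sliding_window`
# falls back to one DynamicSlidingWindowLayer per hidden layer; otherwise the
# result is equivalent to a plain DynamicCache.
from transformers import MistralConfig
from transformers.cache_utils import HybridDynamicCache

config = MistralConfig(num_hidden_layers=2, sliding_window=1024)  # illustrative sizes
cache = HybridDynamicCache(config)
print([type(layer).__name__ for layer in cache.layers])
# e.g. ['DynamicSlidingWindowLayer', 'DynamicSlidingWindowLayer']

If the config defines neither `layer_types` nor a window size, the final branch reproduces the plain `DynamicCache` layout, so the fallback path stays backward compatible.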

src/transformers/models/mistral/modular_mistral.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@

 from transformers.utils.generic import check_model_inputs

-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, HybridDynamicCache
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -132,7 +132,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+            past_key_values = HybridDynamicCache(self.config)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
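The net effect of the Mistral change is that `forward` now allocates the hybrid cache on its own when none is supplied. A quick hedged check under the same assumptions as the sketch above; the checkpoint name is only a placeholder, and any small Mistral-architecture model would do.

# Hedged sketch: with use_cache=True and no cache passed in, MistralModel now
# builds HybridDynamicCache(self.config) instead of a plain DynamicCache, so
# sliding-window layers stop growing past config.sliding_window tokens.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Sliding-window caches keep memory bounded.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)
print(type(outputs.past_key_values).__name__)  # expected: HybridDynamicCache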

0 commit comments
