@@ -37,6 +37,9 @@ def __init__(self):
         self.keys, self.values = None, None
         self.cumulative_length = 0
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
     @abstractmethod
     def update(
         self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
@@ -74,9 +77,9 @@ def reset(self) -> None:
         self.values.zero_()
         self.cumulative_length = 0
 
-    def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
         """Reorders this layer's cache for beam search."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
             self.values = self.values.index_select(0, beam_idx.to(self.values.device))
 
@@ -141,19 +144,19 @@ def crop(self, max_length: int) -> None:
         if self.get_seq_length() <= max_length:
             return
 
-        if self.keys is not None and self.keys.numel():
-            self.keys = self.keys[..., :max_length, :]
-            self.values = self.values[..., :max_length, :]
+        self.keys = self.keys[..., :max_length, :]
+        self.values = self.values[..., :max_length, :]
+        self.cumulative_length = max_length
 
     def batch_repeat_interleave(self, repeats: int) -> None:
         """Repeat the cache `repeats` times in the batch dimension."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys.repeat_interleave(repeats, dim=0)
             self.values = self.values.repeat_interleave(repeats, dim=0)
 
     def batch_select_indices(self, indices: torch.Tensor) -> None:
         """Only keep the `indices` in the batch dimension of the cache."""
-        if self.keys is not None and self.keys.numel():
+        if self.get_seq_length() > 0:
             self.keys = self.keys[indices, ...]
             self.values = self.values[indices, ...]
 
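A minimal sketch of what the new emptiness checks mean for callers (import path and usage are assumptions, not part of this diff): the batch helpers and `crop` are simply no-ops until the first `update` populates the layer.

    # Hedged example: `DynamicLayer` is assumed to be importable from transformers.cache_utils.
    from transformers.cache_utils import DynamicLayer

    layer = DynamicLayer()
    layer.batch_repeat_interleave(4)  # no-op: get_seq_length() is still 0
    layer.crop(8)                     # also a no-op on an empty layer
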
@@ -167,24 +170,9 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
 
     @classmethod
     def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
-        """
-        Build a `DynamicLayer` instance from pre-existing key/value tensors.
-
-        Args:
-            keys (`torch.Tensor`):
-                Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
-            values (`torch.Tensor`):
-                Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
-
-        Returns:
-            `DynamicLayer`: The newly constructed layer whose internal cache directly references
-                the supplied tensors.
-        """
+        """Build a `DynamicLayer` instance from pre-existing key/value tensors."""
         layer = cls()
-        layer.dtype, layer.device = keys.dtype, keys.device
-        layer.cumulative_length = keys.shape[-2]
-        layer.keys = keys
-        layer.values = values
+        _, _ = layer.update(keys, values)
        return layer
 
 
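For illustration, a hedged sketch of the reworked `from_tensors`, which now routes through `update` so dtype, device and `cumulative_length` bookkeeping live in one place (import path and tensor shapes are assumptions):

    import torch
    from transformers.cache_utils import DynamicLayer  # assumed location

    keys = torch.randn(1, 8, 16, 64)   # [batch_size, num_heads, seq_len, head_dim]
    values = torch.randn(1, 8, 16, 64)
    layer = DynamicLayer.from_tensors(keys, values)
    assert layer.get_seq_length() == 16  # cumulative_length is set by update()
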
@@ -198,7 +186,6 @@ class DynamicSlidingWindowLayer(DynamicLayer):
     def __init__(self, sliding_window: int):
         super().__init__()
         self.sliding_window = sliding_window
-        self.cumulative_length = 0
 
     def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
@@ -553,7 +540,6 @@ def __init__(
         self.axis_value = axis_value
         self.q_group_size = q_group_size
         self.residual_length = residual_length
-        self.cumulative_length = 0
 
     def update(
         self,
@@ -599,10 +585,6 @@ def update(
 
         return keys_to_return, values_to_return
 
-    def get_seq_length(self, cache_position=None) -> int:
-        """Returns the sequence length of the cached states."""
-        return self.cumulative_length
-
     @abstractmethod
     def _quantize(self, tensor, axis): ...
 
@@ -710,7 +692,13 @@ def _dequantize(self, qtensor):
         return tensor
 
 
-LAYER_CLASS_MAP: dict[str, type[CacheLayerMixin]] = {
+DYNAMIC_LAYER_CLASS_MAPPING: dict[str, type[CacheLayerMixin]] = {
+    "full_attention": DynamicLayer,
+    "sliding_attention": DynamicSlidingWindowLayer,
+    "chunked_attention": DynamicSlidingWindowLayer,
+}
+
+STATIC_LAYER_CLASS_MAPPING: dict[str, type[CacheLayerMixin]] = {
     "full_attention": StaticLayer,
     "sliding_attention": SlidingWindowLayer,
     "chunked_attention": ChunkedSlidingLayer,
@@ -997,7 +985,7 @@ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.T
         else:
             super().__init__(layer_class_to_replicate=DynamicLayer)
 
-    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
+    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility.
@@ -1008,7 +996,7 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
         return legacy_cache
 
     @classmethod
-    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache":
+    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
         """
         Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
         backward compatibility.
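A hedged round-trip sketch for the legacy tuple-of-tuples format (import path and shapes are assumptions, not taken from this diff):

    import torch
    from transformers.cache_utils import DynamicCache  # assumed location

    # One (key, value) pair per layer, shaped [batch_size, num_heads, seq_len, head_dim].
    legacy = tuple(
        (torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64)) for _ in range(2)
    )
    cache = DynamicCache.from_legacy_cache(legacy)
    assert cache.to_legacy_cache()[0][0].shape == legacy[0][0].shape
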
@@ -1069,6 +1057,27 @@ def _unflatten_dynamic_cache(
     )
 
 
+class HybridDynamicCache(Cache):
+
+    def __init__(self, config: PretrainedConfig):
+        sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
+        if hasattr(config, "layer_types"):
+            layers = []
+            for layer_type in config.layer_types:
+                init_kwargs = {}
+                if layer_type == "sliding_attention":
+                    init_kwargs["sliding_window"] = config.sliding_window
+                elif layer_type == "chunked_attention":
+                    init_kwargs["sliding_window"] = config.attention_chunk_size
+                layers.append(DYNAMIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
+        elif sliding_window is not None:
+            # In this case, fall back to a full sliding cache
+            layers = [DynamicSlidingWindowLayer(sliding_window) for _ in range(config.num_hidden_layers)]
+        else:
+            # In this case, fall back to DynamicCache
+            layers = [DynamicLayer() for _ in range(config.num_hidden_layers)]
+        super().__init__(layers=layers)
+
 class OffloadedCache(Cache):
     """
     A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
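A hedged sketch of how the new `HybridDynamicCache` could be constructed, assuming the class lands in `transformers.cache_utils` as shown above; the config fields below are illustrative only:

    from transformers import PretrainedConfig
    from transformers.cache_utils import HybridDynamicCache  # assumed location

    # Illustrative config: extra kwargs become attributes on PretrainedConfig.
    config = PretrainedConfig(
        num_hidden_layers=4,
        sliding_window=1024,
        layer_types=["sliding_attention", "full_attention", "sliding_attention", "full_attention"],
    )
    # One DynamicSlidingWindowLayer or DynamicLayer per block, following layer_types.
    cache = HybridDynamicCache(config)
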
@@ -1217,7 +1226,7 @@ def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
                     init_kwargs["sliding_window"] = config.sliding_window
                 elif layer_type == "chunked_attention":
                     init_kwargs["sliding_window"] = config.attention_chunk_size
-                layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs))
+                layers.append(STATIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
         else:
             # In this case, fall back to StaticCache
             layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
@@ -1249,7 +1258,7 @@ def __init__(self, max_cache_len: int, config: PretrainedConfig, **kwargs):
                     init_kwargs["sliding_window"] = config.sliding_window
                 elif layer_type == "chunked_attention":
                     init_kwargs["sliding_window"] = config.attention_chunk_size
-                layers.append(LAYER_CLASS_MAP[layer_type](**init_kwargs))
+                layers.append(STATIC_LAYER_CLASS_MAPPING[layer_type](**init_kwargs))
         else:
             # In this case, fall back to StaticCache
             layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]