@@ -54,12 +54,12 @@ def __init__(
 
         # Create a list of CustomKVCache instances, one per layer
         self.kv_cache = torch.nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
+        for layer in self.layers:
             layer_cache = CustomKVCache(
-                max_batch_size=self.max_batch_size,
-                max_context_length=self.max_cache_len,
-                n_heads=self.num_key_value_heads,
-                head_dim=self.head_dim,
+                max_batch_size=layer.max_batch_size,
+                max_context_length=layer.max_cache_len,
+                n_heads=layer.num_heads,
+                head_dim=layer.head_dim,
                 dtype=dtype,
             )
             self.kv_cache.append(layer_cache)
@@ -202,32 +202,29 @@ def __init__(
             layer_device_map=layer_device_map,
         )
 
-        # make sure layer_device_map is none
         assert layer_device_map is None
         assert device is None or device == "cpu", "Device must be None or 'cpu'"
 
         self.cache_position = None
-        # Create a list of cache instances, one per layer
-        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        # Create a list of cache instances, one per layer.
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
         self.kv_cache = torch.nn.ModuleList()
-        for layer_idx in range(config.num_hidden_layers):
-            # newer version of transfomer has is_sliding defined
-            # for HybridCache
-            if self.is_sliding[layer_idx]:
+        for layer in self.layers:
+            if layer.is_sliding():
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.sliding_window_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             else:
                 layer_cache = CustomKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.max_cache_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             self.kv_cache.append(layer_cache)
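For context only (this note and sketch are not part of the diff): the hybrid `__init__` now sizes each cache from the per-layer objects in `self.layers` and asks each layer `is_sliding()`, instead of indexing an `is_sliding` list built from `config.num_hidden_layers`. A minimal runnable sketch of that per-layer dispatch follows; `LayerSpec`, `GlobalKVCache`, and `RingKVCache` are hypothetical stand-ins for the real transformers layer objects and the ExecuTorch `CustomKVCache`/`CustomRingKVCache` classes.

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class LayerSpec:
    # Hypothetical per-layer descriptor exposing the fields the new loop reads.
    max_batch_size: int
    max_cache_len: int
    num_heads: int
    head_dim: int
    sliding_window: bool = False

    def is_sliding(self) -> bool:
        return self.sliding_window


class GlobalKVCache(torch.nn.Module):
    # Stand-in for CustomKVCache: statically sized K/V buffers for one layer.
    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_heads, max_context_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


class RingKVCache(GlobalKVCache):
    # Stand-in for CustomRingKVCache: same storage, intended to be written
    # modulo max_context_length so a sliding window overwrites old entries.
    pass


def build_hybrid_kv_caches(layers: List[LayerSpec], dtype=torch.float32) -> torch.nn.ModuleList:
    # Same construction pattern as the new loop: pick the cache type per layer,
    # sized from the layer object rather than from the model config.
    caches = torch.nn.ModuleList()
    for layer in layers:
        cache_cls = RingKVCache if layer.is_sliding() else GlobalKVCache
        caches.append(
            cache_cls(
                max_batch_size=layer.max_batch_size,
                max_context_length=layer.max_cache_len,
                n_heads=layer.num_heads,
                head_dim=layer.head_dim,
                dtype=dtype,
            )
        )
    return caches
```

Note that the sliding branch now passes `layer.max_cache_len` where it previously passed `self.sliding_window_len`, presumably because the per-layer object already carries the window length for sliding layers.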
@@ -284,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 
         # For CustomRingKVCache, we need to handle the sequence length differently
         layer_cache = self.kv_cache[layer_idx]
-        if self.is_sliding[layer_idx]:
+        if self.layers[layer_idx].is_sliding():
             # CustomRingKVCache cache_position_manager which
             # maintains cache position for each slot in the kv cache
             # we return the max position + 1 to indicate max position
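Aside, not part of the diff: the comment above says that for ring caches the sequence length is taken as the maximum tracked position plus one. A tiny illustration of that arithmetic, using a made-up per-slot position tensor in place of whatever the real cache-position manager stores:

```python
import torch

# Hypothetical per-slot positions for a ring buffer of size 4 that has already
# wrapped: the slots currently hold tokens from absolute positions 4, 5, 2, 3.
cache_positions = torch.tensor([4, 5, 2, 3])

# "max position + 1" reports 6 tokens seen so far, even though only the most
# recent 4 are still resident in the sliding window.
seq_length = int(cache_positions.max().item()) + 1
assert seq_length == 6
```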
@@ -308,7 +305,7 @@ def get_layer_cache(self, layer_idx: int):
 
 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
-    Replace all KV caches in the module with ETCustomStaticCache.
+    Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache.
     This modifies the model in place.
 
     Args:
@@ -342,18 +339,18 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
         if getattr(module, "replace_cache", None) is not None:
             static_cache = ETCustomStaticCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             module.replace_cache(static_cache)
         else:
             module.static_cache = ETCustomStaticCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             # Dont know why we need to this even though
@@ -370,25 +367,25 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
         if getattr(module, "replace_cache", None) is not None:
             hybrid_cache = ETCustomHybridCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             module.replace_cache(hybrid_cache)
         else:
             module.cache = ETCustomHybridCache(
                 config=config,
-                max_batch_size=generation_config.cache_config.batch_size,
-                max_cache_len=generation_config.cache_config.max_cache_len,
-                device=generation_config.cache_config.device,
+                max_batch_size=generation_config.cache_config.get("batch_size"),
+                max_cache_len=generation_config.cache_config.get("max_cache_len"),
+                device=generation_config.cache_config.get("device"),
                 dtype=cache_dtype,
             )
             # Register cache attributes for each layer
             for i in range(len(module.cache.kv_cache)):
                 setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
                 setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-                if module.cache.is_sliding[i]:
+                if module.cache.layers[i].is_sliding():
                     # Register cache_positions as buffer for sliding window layers
                     # This prevents it from being traced as a constant
                     module.register_buffer(
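Finally, a hedged usage sketch (not from this commit) of how the rewritten helper might be driven once `cache_config` is read with `.get(...)`: the dict below and the surrounding wiring are illustrative assumptions, not the project's actual export code; only the `replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)` signature and the `batch_size`/`max_cache_len`/`device` keys are taken from the diff.

```python
import torch


class DummyGenerationConfig:
    # Illustrative stand-in for a generation config whose cache_config supports
    # dict-style .get("batch_size") / .get("max_cache_len") / .get("device"),
    # matching the lookups introduced in the hunks above.
    def __init__(self):
        self.cache_config = {
            "batch_size": 1,
            "max_cache_len": 1024,
            "device": None,  # the caches assert device is None or "cpu"
        }


def swap_in_et_caches(exportable_module, model_config, replace_with_et_custom_kv_cache):
    # The helper mutates the module in place (per its docstring), replacing the
    # HF StaticCache/HybridCache with ETCustomStaticCache/ETCustomHybridCache.
    generation_config = DummyGenerationConfig()
    replace_with_et_custom_kv_cache(
        exportable_module,
        model_config,
        generation_config,
        cache_dtype=torch.float32,
    )
    return exportable_module
```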