Commit c030aa2

fix

1 parent 7a01fcd commit c030aa2

2 files changed: +21 -12 lines changed

src/transformers/cache_utils.py (12 additions, 9 deletions)

@@ -177,16 +177,19 @@ class Cache(CacheBase):
     Parameters:
         model_config (`PretrainedConfig`):
             Model configuration for shape/device info.
-        processor (`CacheProcessor`, *optional*):
+        cache_processor (`CacheProcessor`, *optional*):
             Cache processor to apply (e.g., quantization, offloading).
-    Additional arguments for cache configuration:
-    - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches
-    - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches:
-        * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)`
-        * StaticLayers: uses full `max_cache_len`
-    - `device` (`torch.device`): Device for cache tensors
-    - `dtype` (`torch.dtype`): Data type for cache tensors
-    - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping
+        layer_classes (`list[type[CacheLayer]]`, *optional*):
+            List of layer classes to use for the cache.
+
+    Additional arguments for cache configuration:
+    - `max_batch_size`/`batch_size` (`int`): Maximum batch size for static caches
+    - `max_cache_len` (`int`): Maximum sequence length. For hybrid caches:
+        * SlidingWindowLayers: clamped to `min(sliding_window, max_cache_len)`
+        * StaticLayers: uses full `max_cache_len`
+    - `device` (`torch.device`): Device for cache tensors
+    - `dtype` (`torch.dtype`): Data type for cache tensors
+    - `layer_device_map` (`dict[int, Union[str, torch.device]]`): Per-layer device mapping

     Note for hybrid caches (blocks of (StaticLayer, ..., SlidingWindowLayer) repeated across layers):
     - Requires `model_config.sliding_window` to be set
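The `max_cache_len` clamping rule described in the updated docstring can be checked in isolation. A minimal, self-contained sketch follows; the function name `effective_cache_len` and the example layer layout are illustrative, not part of the transformers API:

# Per-layer cache length rule from the docstring above: sliding layers are
# clamped to the window size, static layers keep the full `max_cache_len`.
def effective_cache_len(is_sliding_layer: bool, max_cache_len: int, sliding_window: int) -> int:
    if is_sliding_layer:
        return min(sliding_window, max_cache_len)
    return max_cache_len

# Example: a hybrid layout of alternating (static, sliding) layers with a
# 1024-token window and a 4096-token cache.
layout = [False, True, False, True]
print([effective_cache_len(s, max_cache_len=4096, sliding_window=1024) for s in layout])
# -> [4096, 1024, 4096, 1024]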

src/transformers/masking_utils.py (9 additions, 3 deletions)

@@ -692,7 +692,9 @@ def create_causal_mask(
         useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the full layers
-    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    is_sliding = []
+    if past_key_values is not None:
+        is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
     layer_idx = is_sliding.index(True) if True in is_sliding else 0

     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
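The same `past_key_values is not None` guard is applied in all three mask builders (causal, sliding-window, and chunked; see the two hunks below). A minimal self-contained sketch of the pattern, using hypothetical `FakeLayer`/`FakeCache` stand-ins rather than the real cache classes:

# Sketch of the guard above: when no cache is passed, `is_sliding` stays
# empty and the layer index falls back to 0. `FakeLayer` and `FakeCache`
# are illustrative stand-ins, not transformers classes.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeLayer:
    is_sliding: bool = False

@dataclass
class FakeCache:
    layers: list

def pick_layer_idx(past_key_values: Optional[FakeCache]) -> int:
    is_sliding = []
    # Without the guard, `past_key_values.layers` raised AttributeError
    # whenever the mask builders were called with past_key_values=None.
    if past_key_values is not None:
        is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
    return is_sliding.index(True) if True in is_sliding else 0

print(pick_layer_idx(None))                                      # 0 (no crash)
print(pick_layer_idx(FakeCache([FakeLayer(), FakeLayer(True)]))) # 1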
@@ -772,7 +774,9 @@ def create_sliding_window_causal_mask(
         useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the sliding layers
-    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    is_sliding = []
+    if past_key_values is not None:
+        is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
     layer_idx = is_sliding.index(True) if True in is_sliding else 0

     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
@@ -857,7 +861,9 @@ def create_chunked_causal_mask(
         useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
     """
     # If we have an HybridCache structure, here we want to create the mask for the sliding layers
-    is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
+    is_sliding = []
+    if past_key_values is not None:
+        is_sliding = [getattr(layer, "is_sliding", False) for layer in past_key_values.layers]
     layer_idx = is_sliding.index(True) if True in is_sliding else 0

     early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
