Commit 6b6314d

remove hack for quantized.get_seq_length and refactor out get_max_shape to masking_utils.py

1 parent: 26c28af

File tree: 2 files changed, +56 -82 lines

  src/transformers/cache_utils.py
  src/transformers/masking_utils.py

src/transformers/cache_utils.py (1 addition, 80 deletions)

@@ -396,32 +396,6 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
             return 0
         return self.layers[layer_idx].get_seq_length()
 
-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        if isinstance(self.layers[layer_idx], SlidingWindowLayer):
-            query_length = cache_position.shape[0]
-            first_cache_position = cache_position[0]
-
-            local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0)
-            # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
-            local_mask_kv_length = max(query_length, self.config.sliding_window)
-            return local_mask_kv_length, local_mask_kv_offset
-
-        full_mask_kv_offset = 0
-        if isinstance(self.layers[layer_idx], StaticLayer):
-            full_mask_kv_length = self.get_max_cache_shape()
-            return full_mask_kv_length, full_mask_kv_offset
-        else:
-            query_length = cache_position.shape[0]
-            past_seen_tokens = self.get_seq_length()
-            kv_length = query_length + past_seen_tokens
-            return kv_length, full_mask_kv_offset
-
     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility."""
@@ -1376,15 +1350,6 @@ def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         return self.self_attention_cache.get_max_cache_shape()
 
-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
-
 
 class HybridCache(Cache):
     """
@@ -1647,37 +1612,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        if self.is_sliding[layer_idx]:
-            query_length = cache_position.shape[0]
-            first_cache_position = cache_position[0]
-
-            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
-            # This is the true general case for any Cache using local attention (sliding or chunked)
-            if first_cache_position >= self.sliding_window:
-                # Here the Cache is already full
-                local_mask_kv_length = self.sliding_window + query_length - 1
-            elif (
-                first_cache_position < self.sliding_window
-                and first_cache_position + query_length > self.sliding_window
-            ):
-                # Here the Cache becomes full with the new input
-                local_mask_kv_length = first_cache_position + query_length
-            else:
-                # Here the Cache is still smaller than the local size, but we return the local size as it's static
-                local_mask_kv_length = self.sliding_window
-            return local_mask_kv_length, local_mask_kv_offset
-
-        full_mask_kv_offset = 0
-        full_mask_kv_length = self.get_max_cache_shape()
-        return full_mask_kv_length, full_mask_kv_offset
-
 
 class OffloadedHybridCache(HybridChunkedCache):
     def __init__(
@@ -1973,14 +1907,13 @@ def __init__(self, cache_config: QuantizedCacheConfig):
         self.config = cache_config
         self._quantized_key_cache: list[torch.Tensor] = []
         self._quantized_value_cache: list[torch.Tensor] = []
-        self._seen_tokens = 0
 
     def init(self, cache: "Cache", **kwargs) -> None:
         """Initialize the quantized processor and validate configuration."""
         self.config.validate()
 
         # Only compatible with DynamicCache
-        if not isinstance(cache, DynamicCache):
+        if not isinstance(cache.layers[0], DynamicLayer):
             raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache")
 
     def post_update(
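
The compatibility check in the hunk above now inspects the type of the first cache layer instead of the concrete Cache subclass. A minimal standalone sketch (toy stand-in classes, not the real transformers API) shows why that pattern is more permissive: any cache built from dynamic layers is accepted, however the cache object itself is wrapped.

    # Toy stand-ins only; the real DynamicLayer and cache classes live in transformers.cache_utils.
    class DynamicLayer: ...
    class StaticLayer: ...

    class ToyCache:
        def __init__(self, layers):
            self.layers = layers

    def validate_for_quantization(cache: ToyCache) -> None:
        # mirrors the new isinstance(cache.layers[0], DynamicLayer) check in the diff
        if not isinstance(cache.layers[0], DynamicLayer):
            raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache")

    validate_for_quantization(ToyCache([DynamicLayer()]))   # passes
    # validate_for_quantization(ToyCache([StaticLayer()]))  # would raise ValueError

The next two hunks drop the _seen_tokens bookkeeping that the old class-level check relied on:
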
@@ -1992,9 +1925,6 @@ def post_update(
         cache_kwargs: Optional[dict[str, Any]] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Apply quantization after cache update."""
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_tensors.shape[-2]
 
         if len(cache.key_cache) < layer_idx:
             raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
@@ -2194,15 +2124,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
 
         super().__init__(processor=processor)
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
-        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
-        # this part of code otherwise fails when used to verify attn_weight shape in some models
-        return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1
-
 
 class QuantoQuantizedCache(QuantizedCache):
     """

src/transformers/masking_utils.py (55 additions, 2 deletions)

@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
 
-from .cache_utils import Cache
+from .cache_utils import Cache, EncoderDecoderCache, HybridChunkedCache, SlidingWindowLayer, StaticLayer
 from .configuration_utils import PretrainedConfig
 from .utils.generic import GeneralInterface
 from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_torchdynamo_compiling
@@ -592,6 +592,59 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
 
 
+def get_mask_sizes(cache, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+    """
+    Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+    the given layer at `layer_idx`.
+    The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+    for each layer.
+    """
+    if isinstance(cache, HybridChunkedCache):  # not yet ported to layer-wise
+        if cache.is_sliding[layer_idx]:
+            query_length = cache_position.shape[0]
+            first_cache_position = cache_position[0]
+
+            local_mask_kv_offset = torch.clamp(first_cache_position - cache.sliding_window + 1, min=0)
+            # This is the true general case for any Cache using local attention (sliding or chunked)
+            if first_cache_position >= cache.sliding_window:
+                # Here the Cache is already full
+                local_mask_kv_length = cache.sliding_window + query_length - 1
+            elif (
+                first_cache_position < cache.sliding_window
+                and first_cache_position + query_length > cache.sliding_window
+            ):
+                # Here the Cache becomes full with the new input
+                local_mask_kv_length = first_cache_position + query_length
+            else:
+                # Here the Cache is still smaller than the local size, but we return the local size as it's static
+                local_mask_kv_length = cache.sliding_window
+            return local_mask_kv_length, local_mask_kv_offset
+
+        return cache.get_max_cache_shape(), 0
+
+    if isinstance(cache, EncoderDecoderCache):
+        cache = cache.attention_cache
+
+    if isinstance(cache.layers[layer_idx], SlidingWindowLayer):
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+
+        local_mask_kv_offset = torch.clamp(first_cache_position - cache.config.sliding_window + 1, min=0)
+        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+        local_mask_kv_length = max(query_length, cache.config.sliding_window)
+        return local_mask_kv_length, local_mask_kv_offset
+
+    full_mask_kv_offset = 0
+    if isinstance(cache.layers[layer_idx], StaticLayer):
+        full_mask_kv_length = cache.get_max_cache_shape()
+        return full_mask_kv_length, full_mask_kv_offset
+    else:
+        query_length = cache_position.shape[0]
+        past_seen_tokens = cache_position.shape[0] if cache_position.shape[0] > 1 else cache_position[0] + 1
+        kv_length = query_length + past_seen_tokens
+        return kv_length, full_mask_kv_offset
+
+
 def _preprocess_mask_arguments(
     config: PretrainedConfig,
     input_embeds: torch.Tensor,
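
The sliding-window branch of the new helper reproduces the arithmetic that previously lived on HybridChunkedCache.get_mask_sizes. A standalone sketch with plain integers (illustrative numbers only, not the transformers API) walks through the three regimes the comments describe:

    # Standalone sketch of the sliding-window mask-size arithmetic from get_mask_sizes,
    # written with plain ints instead of tensors.
    def sliding_mask_sizes(first_cache_position: int, query_length: int, sliding_window: int) -> tuple[int, int]:
        kv_offset = max(first_cache_position - sliding_window + 1, 0)
        if first_cache_position >= sliding_window:
            # cache already full
            kv_length = sliding_window + query_length - 1
        elif first_cache_position + query_length > sliding_window:
            # cache becomes full with this input
            kv_length = first_cache_position + query_length
        else:
            # cache still smaller than the window; the local size is static
            kv_length = sliding_window
        return kv_length, kv_offset

    print(sliding_mask_sizes(first_cache_position=0, query_length=5, sliding_window=8))    # (8, 0)  prefill, window not reached
    print(sliding_mask_sizes(first_cache_position=6, query_length=4, sliding_window=8))    # (10, 0) window filled by this input
    print(sliding_mask_sizes(first_cache_position=20, query_length=1, sliding_window=8))   # (8, 13) decoding with a full window

The final hunk switches the call site in _preprocess_mask_arguments over to the new free function:
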
@@ -649,7 +702,7 @@ def _preprocess_mask_arguments(
 
     # If using a cache, it can give all informations about mask sizes based on seen tokens
     if past_key_values is not None:
-        kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx)
+        kv_length, kv_offset = get_mask_sizes(cache_to_query, cache_position, layer_idx)
     # Otherwise, the sizes are simply the input sizes
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0
