@@ -396,32 +396,6 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
             return 0
         return self.layers[layer_idx].get_seq_length()

-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        if isinstance(self.layers[layer_idx], SlidingWindowLayer):
-            query_length = cache_position.shape[0]
-            first_cache_position = cache_position[0]
-
-            local_mask_kv_offset = torch.clamp(first_cache_position - self.config.sliding_window + 1, min=0)
-            # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
-            local_mask_kv_length = max(query_length, self.config.sliding_window)
-            return local_mask_kv_length, local_mask_kv_offset
-
-        full_mask_kv_offset = 0
-        if isinstance(self.layers[layer_idx], StaticLayer):
-            full_mask_kv_length = self.get_max_cache_shape()
-            return full_mask_kv_length, full_mask_kv_offset
-        else:
-            query_length = cache_position.shape[0]
-            past_seen_tokens = self.get_seq_length()
-            kv_length = query_length + past_seen_tokens
-            return kv_length, full_mask_kv_offset
-
     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility."""
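
For reference, here is the arithmetic the removed method performed for a sliding-window layer, as a small self-contained sketch (the window size and positions below are illustrative, not taken from the patch):

import torch

sliding_window = 4096                       # assumed value of config.sliding_window
cache_position = torch.arange(5000, 5001)   # decoding a single token at position 5000

query_length = cache_position.shape[0]                                      # 1
first_cache_position = cache_position[0]                                    # tensor(5000)
kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)   # tensor(905)
kv_length = max(query_length, sliding_window)                               # 4096
print(int(kv_offset), kv_length)                                            # 905 4096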
@@ -1376,15 +1350,6 @@ def get_max_cache_shape(self) -> int:
         """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
         return self.self_attention_cache.get_max_cache_shape()

-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
-

 class HybridCache(Cache):
     """
@@ -1647,37 +1612,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

-    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
-        """
-        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
-        the given layer at `layer_idx`.
-        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
-        for each layer.
-        """
-        if self.is_sliding[layer_idx]:
-            query_length = cache_position.shape[0]
-            first_cache_position = cache_position[0]
-
-            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
-            # This is the true general case for any Cache using local attention (sliding or chunked)
-            if first_cache_position >= self.sliding_window:
-                # Here the Cache is already full
-                local_mask_kv_length = self.sliding_window + query_length - 1
-            elif (
-                first_cache_position < self.sliding_window
-                and first_cache_position + query_length > self.sliding_window
-            ):
-                # Here the Cache becomes full with the new input
-                local_mask_kv_length = first_cache_position + query_length
-            else:
-                # Here the Cache is still smaller than the local size, but we return the local size as it's static
-                local_mask_kv_length = self.sliding_window
-            return local_mask_kv_length, local_mask_kv_offset
-
-        full_mask_kv_offset = 0
-        full_mask_kv_length = self.get_max_cache_shape()
-        return full_mask_kv_length, full_mask_kv_offset
-

 class OffloadedHybridCache(HybridChunkedCache):
     def __init__(
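
The removed HybridChunkedCache implementation is the general case for local (sliding or chunked) attention, with three regimes for the mask length. A small standalone sketch of that branch logic, using illustrative numbers rather than values from the patch:

def local_mask_kv_length(first_cache_position, query_length, sliding_window):
    # Mirrors the three cases of the removed general local-attention logic.
    if first_cache_position >= sliding_window:
        # cache already full: the window plus the overlap introduced by the new queries
        return sliding_window + query_length - 1
    elif first_cache_position + query_length > sliding_window:
        # cache becomes full while processing this input
        return first_cache_position + query_length
    else:
        # cache still smaller than the window; the static window size is returned
        return sliding_window

print(local_mask_kv_length(0, 4, 8))    # 8: prefill shorter than the window
print(local_mask_kv_length(6, 4, 8))    # 10: the window fills up mid-input
print(local_mask_kv_length(20, 1, 8))   # 8: single-token decoding with a full window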
@@ -1973,14 +1907,13 @@ def __init__(self, cache_config: QuantizedCacheConfig):
         self.config = cache_config
         self._quantized_key_cache: list[torch.Tensor] = []
         self._quantized_value_cache: list[torch.Tensor] = []
-        self._seen_tokens = 0

     def init(self, cache: "Cache", **kwargs) -> None:
         """Initialize the quantized processor and validate configuration."""
         self.config.validate()

         # Only compatible with DynamicCache
-        if not isinstance(cache, DynamicCache):
+        if not isinstance(cache.layers[0], DynamicLayer):
             raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache")

     def post_update(
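
The replaced check above keys off the type of the first cache layer rather than the cache class. A minimal sketch of the difference, using stand-in classes instead of the real transformers ones:

class DynamicLayer: ...                     # stand-in for the real layer class
class DynamicCache:                         # stand-in for the real cache class
    def __init__(self):
        self.layers = [DynamicLayer()]
class CustomCache:                          # not a DynamicCache subclass, but built from dynamic layers
    def __init__(self):
        self.layers = [DynamicLayer()]

print(isinstance(CustomCache(), DynamicCache))            # False: rejected by the old class check
print(isinstance(CustomCache().layers[0], DynamicLayer))  # True: accepted by the new layer check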
@@ -1992,9 +1925,6 @@ def post_update(
         cache_kwargs: Optional[dict[str, Any]] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Apply quantization after cache update."""
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_tensors.shape[-2]

         if len(cache.key_cache) < layer_idx:
             raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
@@ -2194,15 +2124,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:

         super().__init__(processor=processor)

-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
-        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
-        # this part of code otherwise fails when used to verify attn_weight shape in some models
-        return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1
-

 class QuantoQuantizedCache(QuantizedCache):
     """
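
As the removed comments note, `_seen_tokens` was bumped only when layer 0 was updated, so the deleted override reported the full counter for layer 0 and the counter minus one for every other layer. A toy sketch that just restates that removed compensation (the counter value is illustrative):

seen_tokens = 12   # stand-in for processor._seen_tokens after layer 0 of the current step

def removed_get_seq_length(layer_idx: int = 0) -> int:
    # mirrors the removed hack: full counter for layer 0, counter - 1 for any other layer
    return seen_tokens if layer_idx == 0 else seen_tokens - 1

print(removed_get_seq_length(0), removed_get_seq_length(5))   # 12 11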