@@ -1204,7 +1204,7 @@ def __init__(
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
-        )
+        ) // 8  # TODO use TP!
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
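The `// 8` above hard-codes an assumed tensor-parallel degree, as the TODO notes. Below is a minimal sketch of what the TODO presumably points at, assuming the divisor should come from the torch.distributed world size; the helper name is made up for illustration and is not part of the diff.

import torch.distributed as dist

def kv_heads_per_tp_rank(num_key_value_heads: int) -> int:
    # Hypothetical helper: derive the per-rank KV head count from the actual
    # tensor-parallel world size instead of a hard-coded 8.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return max(1, num_key_value_heads // world_size)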
@@ -1663,84 +1663,75 @@ def __init__(
         max_batch_size: int,
         max_cache_len: Optional[int] = None,
         device: Union[torch.device, str, None] = None,
-        dtype: torch.dtype = torch.float32,
+        dtype: torch.dtype = torch.bfloat16,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
-            raise ValueError(
-                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
-                "sliding window attention, please check if there is a `sliding_window` field in the model "
-                "config and it's not set to None."
-            )
+            self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8092)
+        else:
+            self.sliding_window = config.sliding_window
         self.max_cache_len = max_cache_len
         self.max_batch_size = max_batch_size
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )
 
-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding = torch.tensor(
-            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
-        )
+        if hasattr(config.get_text_config(), "no_rope_layers"):
+            self.is_sliding = torch.tensor(config.no_rope_layers)
+        else:
+            layer_switch = getattr(config, "sliding_window_pattern", 2)
+            self.is_sliding = torch.tensor(
+                [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
+            )
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (
             self.max_batch_size,
-            self.num_key_value_heads,
-            min(config.sliding_window, max_cache_len),
+            num_key_value_heads,
+            self.sliding_window,
             self.head_dim,
         )
-        device = torch.device(device) if device is not None and isinstance(device, str) else None
-        for i in range(config.num_hidden_layers):
-            if layer_device_map is not None:
-                layer_device = layer_device_map[i]
-            else:
-                layer_device = device
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache.
-            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
 
     def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-        if cache_position.shape[0] > max_cache_len:
-            k_out = key_states[:, :, -max_cache_len:, :]
-            v_out = value_states[:, :, -max_cache_len:, :]
-            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
-            self.key_cache[layer_idx] += k_out
-            self.value_cache[layer_idx] += v_out
-            # we should return the whole states instead of k_out, v_out to take the whole prompt
-            # into consideration when building kv cache instead of just throwing away tokens outside of the window
-            return key_states, value_states
-
-        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
-        cache_position = cache_position.clamp(0, max_cache_len - 1)
-        to_shift = cache_position >= max_cache_len - 1
-        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-        k_out = k_out[:, :, indices]
-        v_out = v_out[:, :, indices]
-
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
-        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
-        self.key_cache[layer_idx].zero_()
-        self.value_cache[layer_idx].zero_()
-
-        self.key_cache[layer_idx] += k_out
-        self.value_cache[layer_idx] += v_out
-        return k_out, v_out
+        cumulative_length = self.cumulative_length[layer_idx]
+        is_full = cumulative_length >= max_cache_len
+        if is_full:
+            full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
+        elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
+            full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
+        else:
+            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+            self.cumulative_length[layer_idx] += key_states.shape[-2]
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
+        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
+        self.cumulative_length[layer_idx] += key_states.shape[-2]
+        # we should return the whole states instead of k_out, v_out to take the whole prompt
+        # into consideration when building kv cache instead of just throwing away tokens outside of the window
+        return full_key_states, full_value_states
 
     def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
         k_out[:, :, cache_position] = key_states
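The rewritten `_sliding_update` replaces the roll/zero/add dance with per-layer length bookkeeping: while the window is still filling it writes in place, and once tokens overflow it rebuilds the window with `torch.cat` and `copy_`. The toy script below is an illustrative sketch only (not part of the diff; `sliding_update`, `cumulative`, and the shapes are made up) that walks the three branches with a window of four slots.

import torch

max_cache_len = 4                              # sliding window of 4 slots
cache = torch.zeros(1, 1, max_cache_len, 1)    # (batch, kv_heads, window, head_dim)
cumulative = 0                                 # tokens seen so far for this layer

def sliding_update(cache, cumulative, cache_position, new_states):
    if cumulative >= max_cache_len:
        # Window already full: drop the oldest slot and append the new token(s).
        full = torch.cat((cache[:, :, 1:, :], new_states), dim=-2)
    elif cumulative + new_states.shape[2] > max_cache_len:
        # The new chunk overflows the window: concatenate, then keep only the last max_cache_len states.
        full = torch.cat((cache[:, :, :cumulative, :], new_states), dim=-2)
    else:
        # Still filling up: write in place at the given positions.
        cache.index_copy_(2, cache_position, new_states)
        return cache, cumulative + new_states.shape[-2]
    cache.copy_(full[:, :, -max_cache_len:, :])
    return full, cumulative + new_states.shape[-2]

for step in range(6):
    token = torch.full((1, 1, 1, 1), float(step + 1))
    position = torch.tensor([min(step, max_cache_len - 1)])
    _, cumulative = sliding_update(cache, cumulative, position, token)

print(cache[0, 0, :, 0])  # tensor([3., 4., 5., 6.]) -- only the last 4 tokens remain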
@@ -1760,7 +1751,7 @@ def update(
         if cache_kwargs is None:
             cache_kwargs = {}
         cache_position = cache_kwargs.get("cache_position")
-        sliding_window = cache_kwargs.get("sliding_window")
+        self.initialise_cache_layer(layer_idx, key_states)
 
         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1774,7 +1765,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)
 
-        if sliding_window:
+        if self.is_sliding[layer_idx]:
             update_fn = self._sliding_update
         else:
             update_fn = self._static_update
@@ -1801,6 +1792,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0):
                 "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                 "Using the `layer_idx` argument is not supported."
             )
+        if len(self.key_cache) == 0:
+            return 0
         return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
 
 
@@ -1809,6 +1802,7 @@ def reset(self):
             # In-place ops prevent breaking the static address
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
+        self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
 
 
 class MambaCache:
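For context, a rough usage sketch of the hybrid cache path after this change (the checkpoint name and generation settings are placeholders, not taken from the diff); with this patch each layer's cache buffers are allocated lazily on the first `update` call and default to bfloat16.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # assumed example of a model with sliding-window (hybrid) attention
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# cache_implementation="hybrid" asks generate() to build a HybridCache for the run.
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="hybrid")
print(tokenizer.decode(output[0], skip_special_tokens=True))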