@@ -1098,6 +1098,10 @@ class StaticCache(Cache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model, used to adjust the number of key/value heads stored in the cache
+            when the model runs with tensor parallelism. Defaults to `None`, in which case the number of key/value
+            heads is left unchanged.


     Example:
@@ -1130,6 +1134,7 @@ def __init__(
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
@@ -1144,6 +1149,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if self.num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            self.num_key_value_heads //= tp_size

         self.key_cache: list[torch.Tensor] = []
         self.value_cache: list[torch.Tensor] = []
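For reference, a minimal usage sketch of the new argument against `StaticCache`. Everything except `tp_size` follows the upstream signature visible in the hunks above (`config`, `max_batch_size`, `device`, `dtype`); the `max_cache_len` keyword and the toy `LlamaConfig` values are assumptions for illustration, not part of this patch.

```python
import torch
from transformers import LlamaConfig, StaticCache

# Toy config with 8 KV heads (assumed values, chosen only so tp_size divides the head count).
config = LlamaConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=8, num_key_value_heads=8)

# With tp_size=2, each tensor-parallel rank's cache keeps 8 // 2 = 4 KV heads.
cache = StaticCache(
    config=config,
    max_batch_size=1,
    max_cache_len=128,  # assumed keyword; matches the upstream StaticCache signature
    device="cpu",
    dtype=torch.float32,
    tp_size=2,
)
print(cache.num_key_value_heads)  # 4

# tp_size=3 would raise ValueError, since 8 is not divisible by 3.
```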
@@ -1573,6 +1585,10 @@ class HybridCache(Cache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model, used to adjust the number of key/value heads stored in the cache
+            when the model runs with tensor parallelism. Defaults to `None`, in which case the number of key/value
+            heads is left unchanged.

     Example:

@@ -1604,6 +1620,7 @@ def __init__(
         device: Union[torch.device, str, None] = None,
         dtype: torch.dtype = torch.float32,
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1627,6 +1644,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if self.num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            self.num_key_value_heads //= tp_size

         # If the attribute does not exist in the config, fallback to a simple StaticCache
         if hasattr(config, "layer_types"):
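The check and integer division added to `HybridCache` above are identical to the ones in `StaticCache` and `OffloadedStaticCache`. For clarity, a standalone sketch of that shared per-rank computation (the helper name is hypothetical, not part of the patch):

```python
from typing import Optional


def kv_heads_per_rank(num_key_value_heads: int, tp_size: Optional[int]) -> int:
    """Mirror of the per-rank KV-head adjustment duplicated in the three caches."""
    if tp_size is not None and tp_size > 1:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(
                f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
            )
        return num_key_value_heads // tp_size
    return num_key_value_heads


print(kv_heads_per_rank(8, 2))     # 4: each of the 2 ranks caches half of the KV heads
print(kv_heads_per_rank(8, None))  # 8: unchanged when tp_size is not set
# kv_heads_per_rank(8, 3) raises ValueError, since 8 is not divisible by 3.
```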
@@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache):
             Mapping between the layers and its device. This is required when you are manually initializing the cache
             and the model is split between different gpus. You can know which layers mapped to which device by
             checking the associated device_map: `model.hf_device_map`.
+        tp_size (`Optional[int]`, *optional*):
+            The tensor parallel size of the model, used to adjust the number of key/value heads stored in the cache
+            when the model runs with tensor parallelism. Defaults to `None`, in which case the number of key/value
+            heads is left unchanged.

     Example:

@@ -2228,6 +2256,7 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super(Cache, self).__init__()

@@ -2251,6 +2280,13 @@ def __init__(
             if getattr(config, "num_key_value_heads", None) is None
             else config.num_key_value_heads
         )
+        if tp_size is not None and tp_size > 1:
+            if num_key_value_heads % tp_size != 0:
+                raise ValueError(
+                    f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                )
+            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+            num_key_value_heads //= tp_size

         cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)

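How callers obtain `tp_size` is outside this patch. One plausible wiring, sketched under the assumption that `torch.distributed` is already initialized (e.g. via `torchrun`) and that the whole world acts as the tensor-parallel group; real setups may use a dedicated TP subgroup instead.

```python
import torch.distributed as dist

# Assumed wiring, not part of the patch: fall back to 1 (no adjustment) outside distributed runs.
tp_size = dist.get_world_size() if dist.is_initialized() else 1

# These kwargs could then be forwarded when constructing any of the three caches above.
cache_kwargs = {"tp_size": tp_size}
print(cache_kwargs)
```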