@@ -31,9 +31,10 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 
 #########################################################
@@ -139,7 +138,11 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
@@ -192,6 +195,14 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
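
For reference, get_page_size_bytes is imported from the Pallas attention backend and its implementation is not part of this diff. The sketch below is only a plausible reconstruction of what such a helper computes (the byte size of one KV cache page, counting both the K and the V halves); the function name, the factor of 2, and the example numbers are assumptions, not the backend's actual code.

import torch

def approx_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                           kv_cache_dtype: torch.dtype) -> int:
    # One cache page stores block_size token slots, each holding a key and a
    # value vector per KV head.
    num_kv = 2  # K and V
    dtype_bytes = torch.tensor([], dtype=kv_cache_dtype).element_size()
    return block_size * num_kv_heads * head_size * num_kv * dtype_bytes

# Illustrative config: 16-token blocks, 8 KV heads, head_size 128, bfloat16.
print(approx_page_size_bytes(16, 8, 128, torch.bfloat16))  # 65536 (64 KiB)
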
@@ -719,7 +730,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1197,7 +1208,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1220,8 +1232,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
 
 
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel
+    program increases HBM bandwidth utilization via more in-flight DMAs.
+
+    However, it also uses more VMEM, and experimentally, we observed a
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function limits the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(
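
Putting the two new helpers together, here is a rough worked example with hypothetical numbers; prev_power_of_2 is replaced by a local equivalent so the snippet runs without vLLM installed, and the 64 KiB page size is only illustrative.

def prev_power_of_2(n: int) -> int:
    # Local stand-in for vllm.utils.prev_power_of_2: largest power of 2 <= n.
    return 1 << (n.bit_length() - 1) if n > 0 else 0

# Kernel sizing: a 64 KiB KV cache page under the 32 MiB VMEM budget would
# allow 512 slices per kernel block, but the helper caps the value at 64.
page_size_bytes = 64 * 1024
num_slices_per_block = min(
    prev_power_of_2((32 * 1024 * 1024) // page_size_bytes), 64)
assert num_slices_per_block == 64

# Padding: 512 scheduled tokens, max_num_reqs=8, block_size=16 give at most
# 2 * 8 + 512 // 16 = 48 candidate slices, rounded up to a multiple of 64
# so the kernel grid shape stays stable across steps (no recompilation).
padded_num_slices = min(2 * 8 + 512 // 16, 512)
padded_num_slices = ((padded_num_slices + num_slices_per_block - 1)
                     // num_slices_per_block * num_slices_per_block)
assert padded_num_slices == 64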