diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index d97d873ccfdd..1cf392621822 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -734,6 +734,13 @@ def next_power_of_2(n) -> int:
     return 1 << (n - 1).bit_length()
 
 
+def prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
 
@@ -2985,4 +2992,4 @@ def has_deep_ep() -> bool:
 
 def has_deep_gemm() -> bool:
     """Whether the optional `deep_gemm` package is available."""
-    return _has_module("deep_gemm")
\ No newline at end of file
+    return _has_module("deep_gemm")
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 253d79d925ce..0bcb9e901fb8 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -318,3 +318,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                page_size: int,
                                num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
+
+
+def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
+                        kv_cache_dtype: torch.dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f5f26d8fff98..acd383ccf5e3 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -30,9 +30,10 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 
 #########################################################
@@ -139,7 +138,11 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
@@ -192,6 +195,14 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
@@ -719,7 +730,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1178,7 +1189,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1201,8 +1213,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1807,19 +1819,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
 
 
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(
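For reference, a minimal standalone sketch of how the new helpers fit together, mirroring `get_page_size_bytes`, `prev_power_of_2`, and `_get_num_slices_per_kv_cache_update_block` from this diff. The model configuration values (32-token pages, 8 KV heads, head size 128, bfloat16 KV cache) are illustrative only and not taken from any particular model; only the logic follows the patch.

```python
import torch


# Mirrors get_page_size_bytes from pallas.py in this diff.
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize


# Mirrors prev_power_of_2 from vllm/utils/__init__.py in this diff.
def prev_power_of_2(n: int) -> int:
    if n <= 0:
        return 0
    return 1 << (n.bit_length() - 1)


# Mirrors _get_num_slices_per_kv_cache_update_block from tpu_model_runner.py.
def num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    vmem_limit = 32 * 1024 * 1024  # conservative 32 MiB VMEM budget
    slices = prev_power_of_2(vmem_limit // page_size_bytes)
    return min(slices, 64)  # cap at 64 (regression observed at 128 on v6e)


# Illustrative config: 32-token pages, 8 KV heads, head size 128, bfloat16.
page_bytes = get_page_size_bytes(32, 8, 128, torch.bfloat16)
assert page_bytes == 32 * 8 * 128 * 2  # 64 KiB per page
# 32 MiB / 64 KiB = 512 -> previous power of 2 is 512 -> capped to 64.
assert num_slices_per_kv_cache_update_block(page_bytes) == 64
```

Smaller pages (e.g. fp8 KV cache or fewer heads) allow more slices per block before the cap, while very large pages fall back to a smaller power of two.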
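Similarly, a sketch of the padding behavior of `_get_padded_num_kv_cache_update_slices` with the slice count now passed in as an argument. The input numbers (512 scheduled tokens, 16 max requests, 32-token pages, 64 slices per block) are hypothetical; the point is that the slice count is rounded up to a multiple of the per-kernel-block value so the compiled kernel only ever sees a small, fixed set of shapes.

```python
# Mirrors _get_padded_num_kv_cache_update_slices from tpu_model_runner.py.
def padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                      page_size: int,
                                      slices_per_block: int) -> int:
    padded = 2 * max_num_reqs + num_tokens // page_size
    padded = min(padded, num_tokens)
    # Round up to a multiple of the per-block slice count to avoid
    # recompilation for every distinct slice count.
    return (padded + slices_per_block - 1) // slices_per_block * slices_per_block


# 2 * 16 + 512 // 32 = 48 -> min(48, 512) = 48 -> rounded up to 64.
assert padded_num_kv_cache_update_slices(512, 16, 32, 64) == 64
```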