From 33d5969547a68241b46d47d0d3c7c62c31468395 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 5 Sep 2024 09:16:52 -0700 Subject: [PATCH 1/2] Reshape_and_cache_flash kernel to be kv-cache layout aware. Signed-off-by: shuw --- csrc/cache_kernels.cu | 39 ++++++++++--------- tests/kernels/attention/test_cache.py | 56 ++++++++++++++++++--------- vllm/attention/backends/abstract.py | 4 ++ vllm/attention/backends/flashinfer.py | 30 +++++++++++--- vllm/utils.py | 13 +++++-- vllm/worker/cache_engine.py | 23 ++++++++--- 6 files changed, 115 insertions(+), 50 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0b3f6fc8c19a..88559c8fe718 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel( cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, // head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size, - const float* k_scale, const float* v_scale) { + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel( const int head_idx = i / head_size; const int head_offset = i % head_size; const int64_t tgt_key_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; + block_offset * page_stride + + head_idx * head_stride + head_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { @@ -396,16 +397,16 @@ void reshape_and_cache( // KV_T is the data type of key and value tensors. // CACHE_T is the stored data type of kv-cache. // KV_DTYPE is the real data type of kv-cache. 
-#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
-  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
-      <<<grid, block, 0, stream>>>( \
-          reinterpret_cast<KV_T*>(key.data_ptr()), \
-          reinterpret_cast<KV_T*>(value.data_ptr()), \
-          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
-          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
-          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
-          value_stride, num_heads, head_size, block_size, \
-          reinterpret_cast<const float*>(k_scale.data_ptr()), \
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)  \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>  \
+      <<<grid, block, 0, stream>>>(  \
+          reinterpret_cast<KV_T*>(key.data_ptr()),  \
+          reinterpret_cast<KV_T*>(value.data_ptr()),  \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),  \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),  \
+          slot_mapping.data_ptr<int64_t>(), block_stride, page_stride,  \
+          head_stride, key_stride, value_stride, num_heads, head_size,  \
+          block_size, reinterpret_cast<const float*>(k_scale.data_ptr()),  \
           reinterpret_cast<const float*>(v_scale.data_ptr()));

 void reshape_and_cache_flash(
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
   int head_size = key.size(2);
   int block_size = key_cache.size(1);

-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-  int block_stride = key_cache.stride(0);
+  int64_t key_stride = key.stride(0);
+  int64_t value_stride = value.stride(0);
+  int64_t block_stride = key_cache.stride(0);
+  int64_t page_stride = key_cache.stride(1);
+  int64_t head_stride = key_cache.stride(2);
   TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

   dim3 grid(num_tokens);
diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 899122818e0e..2fada663f5d3 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -16,6 +16,7 @@
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
+CACHE_LAYOUTS = ["NHD", "HND"]

 # Parameters for MLA tests.
 KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -232,6 +234,7 @@ def test_reshape_and_cache_flash(
     seed: int,
     device: str,
     kv_cache_dtype: str,
+    kv_cache_layout: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -242,7 +245,6 @@ def test_reshape_and_cache_flash(
     slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)
-
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
@@ -261,27 +263,35 @@ def test_reshape_and_cache_flash(
         kv_cache_dtype,
         dtype,
         device=device,
+        cache_layout=kv_cache_layout,
     )
-    key_cache, value_cache = key_caches[0].contiguous(
-    ), value_caches[0].contiguous()
+    key_cache, value_cache = key_caches[0], value_caches[0]
     del key_caches
     del value_caches

     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)

+    def permute_and_compact(x):
+        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
+        return y.contiguous()
+
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
+
     # Clone the KV caches.
if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(), + cloned_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype) + cloned_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache_compact, + v_scale.item(), kv_cache_dtype) else: - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() - + cloned_key_cache = key_cache_compact.clone() + cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -289,16 +299,20 @@ def test_reshape_and_cache_flash( cond=(head_size == HEAD_SIZES[0])) ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale) + key_cache_compact = permute_and_compact(key_cache) + value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + result_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_key_cache, - key_cache, + key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + result_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_value_cache, - value_cache, + value_cache_compact, v_scale.item(), kv_dtype=kv_cache_dtype) @@ -310,8 +324,12 @@ def test_reshape_and_cache_flash( for i in range(num_tokens): block_idx = block_indicies_lst[i] block_offset = block_offsets_lst[i] - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] + if kv_cache_layout == "NHD": + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + else: + cloned_key_cache[block_idx, :, block_offset, :] = key[i] + cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": torch.testing.assert_close(result_key_cache, @@ -323,8 +341,8 @@ def test_reshape_and_cache_flash( atol=0.001, rtol=0.1) else: - torch.testing.assert_close(key_cache, cloned_key_cache) - torch.testing.assert_close(value_cache, cloned_value_cache) + torch.testing.assert_close(key_cache_compact, cloned_key_cache) + torch.testing.assert_close(value_cache_compact, cloned_value_cache) @pytest.mark.parametrize("direction", COPYING_DIRECTION) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 82d60f9da7da..1aa1652dc802 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -77,6 +77,10 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: raise NotImplementedError + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + raise NotImplementedError + @staticmethod @abstractmethod def swap_blocks( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 718b15e58785..889c4eb9d8e6 100644 --- a/vllm/attention/backends/flashinfer.py +++ 
b/vllm/attention/backends/flashinfer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses +import os from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -48,6 +49,9 @@ from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", + "NHD").upper() + class FlashInferBackend(AttentionBackend): @@ -80,6 +84,14 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + cache_layout = FLASHINFER_KV_CACHE_LAYOUT + assert (cache_layout in ("NHD", "HND")) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, + 2, 4) + return stride_order + @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, @@ -188,6 +200,7 @@ def __init__(self, runner): self.global_hyperparameters: Optional[PerLayerParameters] = None self.vllm_config = self.runner.vllm_config + self._kv_cache_layout = None def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -197,10 +210,15 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer + def get_kv_cache_layout(self): + if self._kv_cache_layout is None: + self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT + return self._kv_cache_layout + def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") + self._get_workspace_buffer(), self.get_kv_cache_layout()) return self._prefill_wrapper def _get_decode_wrapper(self): @@ -213,7 +231,7 @@ def _get_decode_wrapper(self): num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), - "NHD", + self.get_kv_cache_layout(), use_tensor_cores=use_tensor_cores) return self._decode_wrapper @@ -274,7 +292,8 @@ def graph_capture_get_metadata_for_batch( self._graph_decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, "NHD", + self._graph_indices_buffer, _last_page_len_buffer, + self.get_kv_cache_layout(), use_tensor_cores) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -1005,6 +1024,7 @@ def forward( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None + stride_order = FlashInferBackend.get_kv_cache_stride_order() if prefill_meta := attn_metadata.prefill_metadata: # We will use flash attention for prefill # when kv_cache is not provided. 
@@ -1036,7 +1056,7 @@ def forward( prefill_output = prefill_meta.prefill_wrapper.run( query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) @@ -1051,7 +1071,7 @@ def forward( decode_output = decode_meta.decode_wrapper.run( decode_query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) diff --git a/vllm/utils.py b/vllm/utils.py index ed406a6b7b11..96a82530afaa 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -762,21 +762,28 @@ def create_kv_caches_with_random_flash( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = None, device: Optional[str] = "cuda", + cache_layout: Optional[str] = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, + 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] + for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=key_value_cache_shape, + key_value_cache = torch.empty(size=kv_cache_allocation_shape, dtype=torch_dtype, - device=device) + device=device).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) elif cache_dtype == 'fp8': diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 85ebe8121e52..d48a6957c5dd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -71,19 +71,32 @@ def _allocate_kv_cache( device: str, ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape))) + + # The allocation respects the backend-defined stride order to ensure + # the semantic remains consistent for each backend. We first obtain the + # generic kv cache shape and then permute it according to the stride + # order which could result in a non-contiguous tensor. + kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i] + for i in kv_cache_stride_order) for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device) + layer_kv_cache = torch.zeros( + kv_cache_allocation_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device).permute(*kv_cache_stride_order) # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) 
for cases
# when entry_shape is higher than 1D

From a48be76f41fccb45800f497b12eaec0a37a369e0 Mon Sep 17 00:00:00 2001
From: shuw
Date: Thu, 24 Apr 2025 21:37:47 +0000
Subject: [PATCH 2/2] Reduce test size

Signed-off-by: shuw
---
 tests/kernels/attention/test_cache.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py
index 2fada663f5d3..2f2212dd2b0e 100644
--- a/tests/kernels/attention/test_cache.py
+++ b/tests/kernels/attention/test_cache.py
@@ -239,6 +239,10 @@ def test_reshape_and_cache_flash(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

+    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
+    # blocks and tokens to consume less memory.
+    num_tokens = num_tokens // 2
+    num_blocks = num_blocks // 2
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
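
Note on the allocation trick used in cache_engine.py and create_kv_caches_with_random_flash above: the KV cache keeps the generic logical shape (num_blocks, 2, block_size, num_kv_heads, head_size) while the physical layout follows the backend's stride order, by allocating the permuted shape contiguously and permuting the result back. The minimal standalone PyTorch sketch below (illustrative only, not part of the patch; the tensor sizes are arbitrary) mirrors that pattern for the two stride orders used in this change. It relies on the fact that (0, 1, 3, 2, 4) is its own inverse, so applying the same permutation to the permuted allocation restores the generic shape.

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
generic_shape = (num_blocks, 2, block_size, num_heads, head_size)

# Stride orders as in FlashInferBackend.get_kv_cache_stride_order():
# identity for NHD, swap of the block_size/num_heads axes for HND.
for layout, stride_order in (("NHD", (0, 1, 2, 3, 4)),
                             ("HND", (0, 1, 3, 2, 4))):
    # Allocate contiguously in the backend's preferred order ...
    alloc_shape = tuple(generic_shape[i] for i in stride_order)
    kv_cache = torch.zeros(alloc_shape).permute(*stride_order)
    # ... then permute back: the logical shape is always the generic one,
    # only the strides (and hence contiguity) differ between layouts.
    assert kv_cache.shape == generic_shape
    print(layout, kv_cache.stride(), kv_cache.is_contiguous())

Running this prints a contiguous tensor for NHD and a non-contiguous view for HND, which is why the test above compacts the caches with permute(0, 2, 1, 3).contiguous() before comparing them element-wise.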