From fda61f7c1f612cc1dc95f82dbd504b298d86d7ef Mon Sep 17 00:00:00 2001
From: nicklucche
Date: Tue, 27 May 2025 18:29:39 +0000
Subject: [PATCH] kv cache stride order for V1

Signed-off-by: nicklucche
---
 tests/v1/worker/test_gpu_model_runner.py | 71 +++++++++++++++++++-----
 vllm/v1/worker/gpu_model_runner.py       | 26 ++++++++-
 2 files changed, 81 insertions(+), 16 deletions(-)

diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index b8c3d88617d0..c38eb486646f 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import random
+
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
@@ -13,27 +16,30 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+BLOCK_SIZE = 16
+NUM_BLOCKS = 10
+
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
+    attn_spec = FullAttentionSpec(
+        block_size=BLOCK_SIZE,
+        num_kv_heads=runner.model_config.get_num_kv_heads(
+            runner.parallel_config),
+        head_size=runner.model_config.get_head_size(),
+        dtype=runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
     kv_cache_config = KVCacheConfig(
-        num_blocks=10,
+        num_blocks=NUM_BLOCKS,
         tensors={
-            "layer.0": KVCacheTensor(size=1024),
+            "layer.0": KVCacheTensor(size=tensor_size),
         },
         kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=16,
-                    num_kv_heads=runner.model_config.get_num_kv_heads(
-                        runner.parallel_config),
-                    head_size=runner.model_config.get_head_size(),
-                    dtype=runner.kv_cache_dtype,
-                    use_mla=False,
-                ))
+            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
         ])
     runner.kv_cache_config = kv_cache_config
     runner.input_batch = InputBatch(
@@ -65,7 +71,7 @@ def model_runner():
         seed=42,
     )
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
@@ -77,6 +83,10 @@
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
     )
+    num_heads = model_config.get_num_kv_heads(parallel_config)
+    head_size = model_config.get_head_size()
+    vllm_config.compilation_config.static_forward_context[
+        "layer.0"] = Attention(num_heads, head_size, 0.1)
 
     device = "cuda"
     runner = GPUModelRunner(vllm_config, device)
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_kv_cache_stride_order(monkeypatch, model_runner):
+    # This test checks that GPUModelRunner initializes correctly when an
+    # attention backend enforces a non-default KV cache stride order.
+    n_heads = model_runner.model_config.get_num_kv_heads(
+        model_runner.parallel_config)
+    expected_kv_cache_shape = [
+        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
+        model_runner.model_config.get_head_size()
+    ]
+    # TODO mla test
+    default_stride = tuple(range(5))
+    # Random permutation of the default stride order
+    rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
+
+    def rnd_stride_order():
+        return rnd_stride
+
+    # Patch the attention backend class and re-trigger the KV cache creation.
+    for attn_backend in model_runner.attn_backends:
+        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
+                            rnd_stride_order)
+
+    model_runner.attn_backends = []
+    model_runner.attn_metadata_builders = []
+    model_runner.initialize_kv_cache(model_runner.kv_cache_config)
+
+    # Shape is unchanged, but the memory layout may differ.
+    kv_cache_shape = model_runner.kv_caches[0].shape
+    assert list(kv_cache_shape) == expected_kv_cache_shape
+    if default_stride == rnd_stride:
+        assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
+    else:
+        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 910c0e80bb31..fd9990eb5528 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2018,9 +2018,29 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                        dtype=dtype,
-                                                        device=self.device)
+                    try:
+                        kv_cache_stride_order = self.attn_backends[
+                            i].get_kv_cache_stride_order()
+                        assert len(kv_cache_stride_order) == len(
+                            kv_cache_shape)
+                    except (AttributeError, NotImplementedError):
+                        kv_cache_stride_order = tuple(
+                            range(len(kv_cache_shape)))
+                    # The allocation respects the backend-defined stride order
+                    # to ensure the semantics remain consistent for each
+                    # backend. We first obtain the generic kv cache shape and
+                    # then permute it according to the stride order, which
+                    # could result in a non-contiguous tensor.
+                    kv_cache_shape = tuple(kv_cache_shape[i]
+                                           for i in kv_cache_stride_order)
+                    # Maintain the original KV shape view.
+                    inv_order = [
+                        kv_cache_stride_order.index(i)
+                        for i in range(len(kv_cache_stride_order))
+                    ]
+                    kv_caches[layer_name] = torch.zeros(
+                        kv_cache_shape, dtype=dtype,
+                        device=self.device).permute(*inv_order)
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
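
For reference, the allocate-then-permute round trip added to initialize_kv_cache()
can be exercised in isolation with a few lines of PyTorch. The sketch below is
illustrative only and is not part of the diff above: DemoBackend, the head count,
and the head size are made up, while the (2, num_blocks, block_size, num_kv_heads,
head_size) layout and the get_kv_cache_stride_order() hook mirror the ones used in
the patch.

    import torch

    class DemoBackend:
        """Hypothetical backend that wants num_blocks as the outermost dim."""

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            # Indices refer to the logical shape
            # (2, num_blocks, block_size, num_kv_heads, head_size).
            return (1, 0, 2, 3, 4)

    # Example sizes: NUM_BLOCKS/BLOCK_SIZE as in the test, heads/head_size arbitrary.
    logical_shape = (2, 10, 16, 8, 128)
    order = DemoBackend.get_kv_cache_stride_order()

    # Allocate physically in the backend-preferred order...
    physical_shape = tuple(logical_shape[i] for i in order)
    # ...then permute back so callers still index the logical layout.
    inv_order = [order.index(i) for i in range(len(order))]
    kv_cache = torch.zeros(physical_shape).permute(*inv_order)

    assert tuple(kv_cache.shape) == logical_shape
    assert not kv_cache.is_contiguous()  # same shape, different memory layout

Keeping the permuted allocation behind a permuted-back view means the rest of the
runner keeps indexing the logical shape regardless of which physical layout a
backend picks.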