 # SPDX-License-Identifier: Apache-2.0
 
+import random
+
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+BLOCK_SIZE = 16
+NUM_BLOCKS = 10
+
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
+    attn_spec = FullAttentionSpec(
+        block_size=BLOCK_SIZE,
+        num_kv_heads=runner.model_config.get_num_kv_heads(
+            runner.parallel_config),
+        head_size=runner.model_config.get_head_size(),
+        dtype=runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
     kv_cache_config = KVCacheConfig(
-        num_blocks=10,
+        num_blocks=NUM_BLOCKS,
         tensors={
-            "layer.0": KVCacheTensor(size=1024),
+            "layer.0": KVCacheTensor(size=tensor_size),
         },
         kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=16,
-                    num_kv_heads=runner.model_config.get_num_kv_heads(
-                        runner.parallel_config),
-                    head_size=runner.model_config.get_head_size(),
-                    dtype=runner.kv_cache_dtype,
-                    use_mla=False,
-                ))
+            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
         ])
     runner.kv_cache_config = kv_cache_config
     runner.input_batch = InputBatch(
@@ -65,7 +71,7 @@ def model_runner():
         seed=42,
     )
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
@@ -77,6 +83,10 @@ def model_runner():
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
     )
+    num_heads = model_config.get_num_kv_heads(parallel_config)
+    head_size = model_config.get_head_size()
+    vllm_config.compilation_config.static_forward_context[
+        "layer.0"] = Attention(num_heads, head_size, 0.1)
 
     device = "cuda"
     runner = GPUModelRunner(vllm_config, device)
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_kv_cache_stride_order(monkeypatch, model_runner):
+    # This test checks if GPUModelRunner initializes correctly when an
+    # attention backend enforces a non-default KV cache stride order.
+    n_heads = model_runner.model_config.get_num_kv_heads(
+        model_runner.parallel_config)
+    expected_kv_cache_shape = [
+        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
+        model_runner.model_config.get_head_size()
+    ]
+    # TODO mla test
+    default_stride = list(range(5))
+    # Random stride order: a permutation of the default dimension order.
+    rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
+
+    def rnd_stride_order():
+        return rnd_stride
+
+    # Patch the attention backend class and re-trigger the KV cache creation.
+    for attn_backend in model_runner.attn_backends:
+        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
+                            rnd_stride_order)
+
+    model_runner.attn_backends = []
+    model_runner.attn_metadata_builders = []
+    model_runner.initialize_kv_cache(model_runner.kv_cache_config)
+
+    # Shape is unchanged, but layout may differ
+    kv_cache_shape = model_runner.kv_caches[0].shape
+    assert list(kv_cache_shape) == expected_kv_cache_shape
+    if default_stride == list(rnd_stride):
+        assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
+    else:
+        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
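
Note: the monkeypatched get_kv_cache_stride_order above stands in for a backend-provided hook. Below is a minimal, hypothetical sketch of what such a hook and an allocator honoring it could look like; the names StrideOrderBackend and allocate_kv_cache are assumptions made for illustration, not vLLM's actual implementation.

import torch


class StrideOrderBackend:
    """Hypothetical backend that prefers a non-default physical KV layout."""

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # Permutation of the default (2, num_blocks, block_size, num_kv_heads,
        # head_size) dimension order.
        return (0, 3, 1, 2, 4)


def allocate_kv_cache(shape: tuple[int, ...],
                      stride_order: tuple[int, ...]) -> torch.Tensor:
    # Allocate in the backend-preferred physical order, then permute back so
    # the logical shape equals `shape` while the memory layout follows
    # `stride_order`; the result is non-contiguous unless the order is the
    # identity permutation.
    physical_shape = [shape[d] for d in stride_order]
    inverse = [stride_order.index(d) for d in range(len(shape))]
    return torch.zeros(physical_shape, dtype=torch.float16).permute(inverse)


# Usage: the logical shape is preserved, the physical layout is not.
kv = allocate_kv_cache((2, 10, 16, 4, 64),
                       StrideOrderBackend.get_kv_cache_stride_order())
assert list(kv.shape) == [2, 10, 16, 4, 64]
assert not kv.is_contiguous()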