
Commit de02185

NickLucche authored and amitm02 committed
[V1] Allocate kv_cache with stride order for V1 (vllm-project#18775)
Signed-off-by: nicklucche <nlucches@redhat.com>
Signed-off-by: amit <amit.man@gmail.com>
1 parent c3cd0ee commit de02185

2 files changed: 81 additions & 16 deletions

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 58 additions & 13 deletions
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import random
+
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
@@ -13,27 +16,30 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+BLOCK_SIZE = 16
+NUM_BLOCKS = 10
+
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
+    attn_spec = FullAttentionSpec(
+        block_size=BLOCK_SIZE,
+        num_kv_heads=runner.model_config.get_num_kv_heads(
+            runner.parallel_config),
+        head_size=runner.model_config.get_head_size(),
+        dtype=runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
     kv_cache_config = KVCacheConfig(
-        num_blocks=10,
+        num_blocks=NUM_BLOCKS,
         tensors={
-            "layer.0": KVCacheTensor(size=1024),
+            "layer.0": KVCacheTensor(size=tensor_size),
         },
         kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=16,
-                    num_kv_heads=runner.model_config.get_num_kv_heads(
-                        runner.parallel_config),
-                    head_size=runner.model_config.get_head_size(),
-                    dtype=runner.kv_cache_dtype,
-                    use_mla=False,
-                ))
+            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
         ])
     runner.kv_cache_config = kv_cache_config
     runner.input_batch = InputBatch(
@@ -65,7 +71,7 @@ def model_runner():
         seed=42,
     )
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
@@ -77,6 +83,10 @@ def model_runner():
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
     )
+    num_heads = model_config.get_num_kv_heads(parallel_config)
+    head_size = model_config.get_head_size()
+    vllm_config.compilation_config.static_forward_context[
+        "layer.0"] = Attention(num_heads, head_size, 0.1)
 
     device = "cuda"
     runner = GPUModelRunner(vllm_config, device)
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_kv_cache_stride_order(monkeypatch, model_runner):
+    # This test checks if GPUModelRunner initializes correctly when an attention
+    # backend enforces a non-default KV cache stride order.
+    n_heads = model_runner.model_config.get_num_kv_heads(
+        model_runner.parallel_config)
+    expected_kv_cache_shape = [
+        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
+        model_runner.model_config.get_head_size()
+    ]
+    # TODO mla test
+    default_stride = list(range(5))
+    # Permutation that gets you back to expected kv shape
+    rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
+
+    def rnd_stride_order():
+        return rnd_stride
+
+    # Patch the attention backend class and re-trigger the KV cache creation.
+    for attn_backend in model_runner.attn_backends:
+        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
+                            rnd_stride_order)
+
+    model_runner.attn_backends = []
+    model_runner.attn_metadata_builders = []
+    model_runner.initialize_kv_cache(model_runner.kv_cache_config)
+
+    # Shape is unchanged, but layout may differ
+    kv_cache_shape = model_runner.kv_caches[0].shape
+    assert list(kv_cache_shape) == expected_kv_cache_shape
+    if default_stride == rnd_stride:
+        assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
+    else:
+        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
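The test above monkeypatches get_kv_cache_stride_order onto whichever attention backends the runner picked up, so any backend object exposing that method participates. For orientation, a backend with a layout preference could expose the hook roughly as in the sketch below; the class name and the particular order are illustrative assumptions, and only the two method signatures mirror what the test and initialize_kv_cache() rely on.

# Hypothetical sketch, not part of this commit: the contract that
# test_kv_cache_stride_order patches and initialize_kv_cache() consumes.
class DummyAttentionBackend:

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int):
        # Generic logical shape the model runner reasons about.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order():
        # Preferred physical dimension order. Backends without a preference
        # can simply omit this method (or raise NotImplementedError); the
        # runner then falls back to the default contiguous layout.
        return (1, 0, 2, 3, 4)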

vllm/v1/worker/gpu_model_runner.py

Lines changed: 23 additions & 3 deletions
@@ -2033,9 +2033,29 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                        dtype=dtype,
-                                                        device=self.device)
+                    try:
+                        kv_cache_stride_order = self.attn_backends[
+                            i].get_kv_cache_stride_order()
+                        assert len(kv_cache_stride_order) == len(
+                            kv_cache_shape)
+                    except (AttributeError, NotImplementedError):
+                        kv_cache_stride_order = tuple(
+                            range(len(kv_cache_shape)))
+                    # The allocation respects the backend-defined stride order
+                    # to ensure the semantic remains consistent for each
+                    # backend. We first obtain the generic kv cache shape and
+                    # then permute it according to the stride order which could
+                    # result in a non-contiguous tensor.
+                    kv_cache_shape = tuple(kv_cache_shape[i]
+                                           for i in kv_cache_stride_order)
+                    # Maintain original KV shape view.
+                    inv_order = [
+                        kv_cache_stride_order.index(i)
+                        for i in range(len(kv_cache_stride_order))
+                    ]
+                    kv_caches[layer_name] = torch.zeros(
+                        kv_cache_shape, dtype=dtype,
+                        device=self.device).permute(*inv_order)
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
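The mechanism is plain PyTorch: allocate the buffer with its dimensions permuted into the backend's preferred order, then permute it back with the inverse order so callers keep seeing the generic (2, num_blocks, block_size, num_kv_heads, head_size) view over the same storage. A minimal standalone sketch with illustrative sizes (not vLLM code):

import torch

logical_shape = (2, 10, 16, 8, 64)  # (2, num_blocks, block_size, heads, head_size)
stride_order = (1, 0, 2, 3, 4)      # e.g. a backend that wants num_blocks outermost

# Allocate with dimensions permuted into the preferred physical order.
physical_shape = tuple(logical_shape[i] for i in stride_order)

# Inverse permutation: physical dim stride_order.index(i) goes back to
# logical position i, exactly as in the new initialize_kv_cache() code.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
kv_cache = torch.zeros(physical_shape).permute(*inv_order)

assert tuple(kv_cache.shape) == logical_shape  # logical view unchanged
assert not kv_cache.is_contiguous()  # memory layout follows stride_order

Because permute() returns a view, reads and writes through the logical view land in memory laid out according to stride_order, which is what lets a backend keep its preferred physical layout without changing any caller code.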
