Skip to content

Commit 6f0f570

Browse files
authored
[deepseek] kernel block size for UniformTypeKVCacheSpecs (#26559)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent b545a0b commit 6f0f570

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

vllm/v1/attention/backends/mla/indexer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from dataclasses import dataclass
4-
from typing import ClassVar, Optional
4+
from typing import ClassVar, Optional, Union
55

66
import torch
77

8-
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
8+
from vllm.attention.backends.abstract import (
9+
AttentionBackend,
10+
AttentionMetadata,
11+
MultipleOf,
12+
)
913
from vllm.config import VllmConfig
1014
from vllm.logger import init_logger
1115
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
@@ -47,6 +51,10 @@ def get_kv_cache_shape(
4751
def get_kv_cache_stride_order() -> tuple[int, ...]:
4852
return (0, 1, 2)
4953

54+
@classmethod
55+
def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
56+
return [64]
57+
5058

5159
@dataclass
5260
class DeepseekV32IndexerPrefillChunkMetadata:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4242,9 +4242,14 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
42424242
for kv_cache_group_id, kv_cache_group in enumerate(
42434243
kv_cache_config.kv_cache_groups
42444244
):
4245-
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
4245+
kv_cache_spec = kv_cache_group.kv_cache_spec
4246+
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
4247+
# All layers in the UniformTypeKVCacheSpecs have the same type.
4248+
# Pick an arbitrary one to dispatch.
4249+
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
4250+
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
42464251
continue
4247-
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
4252+
elif isinstance(kv_cache_spec, AttentionSpec):
42484253
# This is an attention backend that supports virtual
42494254
# block splitting. Get the supported block sizes from
42504255
# all backends in the group.
@@ -4254,10 +4259,10 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
42544259
kv_manager_block_size, attn_groups
42554260
)
42564261
kernel_block_sizes.append(selected_kernel_size)
4257-
elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
4262+
elif isinstance(kv_cache_spec, MambaSpec):
42584263
# This is likely Mamba or other non-attention cache,
42594264
# no splitting.
4260-
kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
4265+
kernel_block_sizes.append(kv_cache_spec.block_size)
42614266
else:
42624267
raise NotImplementedError(
42634268
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"

0 commit comments

Comments
 (0)