Commit 3bee94b

NickLucche and albertoperdomo2 authored and committed

[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (vllm-project#26587)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 6222e75 commit 3bee94b

File tree: 10 files changed, +151 −118 lines

vllm/attention/layer.py

Lines changed: 53 additions & 5 deletions
@@ -16,6 +16,7 @@
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,
@@ -34,7 +35,16 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)
 
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
@@ -152,6 +162,7 @@ def __init__(
         else:
             sliding_window = None
 
+        vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
@@ -160,6 +171,9 @@
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
         if num_kv_heads is None:
             num_kv_heads = num_heads
         assert num_heads % num_kv_heads == 0, (
@@ -256,7 +270,7 @@ def __init__(
         self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -276,9 +290,7 @@ def __init__(
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
 
         # Initialize q/k/v range constants.
@@ -394,6 +406,30 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+
 
 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
@@ -749,6 +785,18 @@ def calc_kv_scales(
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+
 
 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():

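With get_kv_cache_spec now defined on the layers themselves (returning a SlidingWindowSpec, FullAttentionSpec, or MLAAttentionSpec as appropriate), callers can ask each registered layer for its own spec instead of special-casing layer types. A minimal sketch of that calling pattern; collect_kv_cache_specs is an illustrative helper name, not part of this commit:

```python
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase], vllm_config: VllmConfig
) -> dict[str, KVCacheSpec]:
    """Ask each layer for its own KV cache spec; skip layers that need none."""
    specs: dict[str, KVCacheSpec] = {}
    for name, layer in layers.items():
        spec = layer.get_kv_cache_spec(vllm_config)
        if spec is not None:  # e.g. encoder-only attention returns None
            specs[name] = spec
    return specs
```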
vllm/attention/layers/chunked_local_attention.py

Lines changed: 13 additions & 0 deletions
@@ -9,13 +9,15 @@
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     CommonAttentionMetadata,
     make_local_attention_virtual_batches,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
 
 from ..layer import Attention
 
@@ -67,6 +69,7 @@ def __init__(
         kv_sharing_target_layer_name: str | None = None,
         prefix: str = "",
     ):
+        self.attention_chunk_size = attention_chunk_size
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
@@ -99,3 +102,13 @@ def __init__(
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
             attn_backend=attn_backend,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        assert self.attention_chunk_size
+        return ChunkedLocalAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+            attention_chunk_size=self.attention_chunk_size,
+        )

vllm/attention/layers/cross_attention.py

Lines changed: 9 additions & 1 deletion
@@ -21,7 +21,7 @@
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
-from vllm.v1.kv_cache_interface import CrossAttentionSpec
+from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
 
 logger = init_logger(__name__)
 
@@ -174,3 +174,11 @@ def __init__(
             attn_type=AttentionType.ENCODER_DECODER,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        return CrossAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+        )

vllm/attention/layers/encoder_only_attention.py

Lines changed: 6 additions & 0 deletions
@@ -14,10 +14,12 @@
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import KVCacheSpec
 
 
 @functools.lru_cache
@@ -98,3 +100,7 @@ def __init__(
             attn_type=AttentionType.ENCODER_ONLY,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Does not need KV cache
+        return None

vllm/model_executor/layers/attention_layer_base.py

Lines changed: 11 additions & 0 deletions
@@ -5,6 +5,9 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
+from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheSpec
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this layer."""
         pass
+
+    @abstractmethod
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """
+        Get the KV cache spec for this layer.
+        May be None if the layer does not need KV cache.
+        """
+        pass

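To illustrate the contract the new abstract method adds, here is a hedged sketch of a custom layer; ToyAttention, its attributes, and the stubbed get_attn_backend are made up for the example and are not part of this commit:

```python
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec


class ToyAttention(AttentionLayerBase):
    """Illustrative subclass showing the two abstract methods."""

    def __init__(self, num_kv_heads: int, head_size: int, dtype):
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.kv_cache_torch_dtype = dtype

    def get_attn_backend(self):
        raise NotImplementedError  # backend selection elided in this sketch

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Return None instead if this layer allocates no KV cache.
        return FullAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
```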
vllm/model_executor/layers/mamba/abstract.py

Lines changed: 29 additions & 0 deletions
@@ -6,7 +6,9 @@
 
 import torch
 
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -40,3 +42,30 @@ def mamba_type(self) -> str:
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this Mamba layer."""
         pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )

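With the default get_kv_cache_spec above, a concrete Mamba layer only has to describe its per-request state. A rough sketch, assuming mamba_type is exposed as a property and get_state_shape is part of the existing interface; the class name, shapes, and dtypes below are placeholders:

```python
import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class ToyMambaLayer(MambaBase):
    """Illustrative only; real layers also implement the forward pass."""

    @property
    def mamba_type(self) -> str:
        return "mamba2"

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # Placeholder conv-state and SSM-state shapes.
        return ((4, 1024), (16, 64, 128))

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        return (torch.float16, torch.float32)

    def get_attn_backend(self):
        raise NotImplementedError  # backend selection elided in this sketch

    # get_kv_cache_spec() is inherited from MambaBase and builds the MambaSpec.
```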
vllm/model_executor/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -481,7 +481,7 @@ def __init__(
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
         return MLAAttentionSpec(  # Only has one vector instead of K + V
             block_size=self.cache_config.block_size,
             num_kv_heads=1,

vllm/utils/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
     torch.set_num_threads(old_num_threads)
 
 
+def kv_cache_dtype_str_to_dtype(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> torch.dtype:
+    if kv_cache_dtype == "auto":
+        # Model config may not be specified for unit tests, default to float16
+        return model_config.dtype if model_config else torch.half
+    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
+
+
 T = TypeVar("T")
 U = TypeVar("U")

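A small usage sketch for the new helper. The "auto" behaviour follows directly from the function body above; the "bfloat16" lookup assumes that key is present in STR_DTYPE_TO_TORCH_DTYPE:

```python
import torch

from vllm.utils import kv_cache_dtype_str_to_dtype

# "auto" defers to the model dtype, or float16 when no model config is given
# (the unit-test fallback noted in the function's comment).
assert kv_cache_dtype_str_to_dtype("auto", model_config=None) == torch.half

# Explicit strings are looked up in STR_DTYPE_TO_TORCH_DTYPE.
assert kv_cache_dtype_str_to_dtype("bfloat16", None) == torch.bfloat16
```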
vllm/v1/spec_decode/eagle.py

Lines changed: 1 addition & 1 deletion
@@ -948,7 +948,7 @@ def load_model(self, target_model: nn.Module) -> None:
                 indexer_layers[first_layer]
                 .get_attn_backend()
                 .get_builder_cls()(
-                    indexer_layers[first_layer].get_kv_cache_spec(),
+                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                     self.indexer_layer_names,
                     self.vllm_config,
                     self.device,
