Commit 17373dc

[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 5964069 commit 17373dc

File tree

12 files changed (+227 additions, -215 deletions)
tests/v1/worker/test_gpu_model_runner.py

Lines changed: 7 additions & 4 deletions

@@ -680,17 +680,20 @@ def test_init_kv_cache_with_kv_sharing_valid():
         kv_cache_spec[layer_0].page_size_bytes
 
     runner.initialize_kv_cache(kv_cache_config)
+    kv_cache_config_after_init = runner.kv_cache_config
 
     layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
     layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
     # check layer 1 kv cache shares memory with layer 0
     assert id(layer_1_kv) == id(layer_0_kv)
 
     # check layer 1 added to kv cache group's layer names
-    assert len(kv_cache_config.kv_cache_groups) == 1
-    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
-    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
-    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+    assert len(kv_cache_config_after_init.kv_cache_groups) == 1
+    assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
+        0] == layer_0
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
+        1] == layer_1
 
 
 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):

vllm/attention/layers/chunked_local_attention.py

Lines changed: 16 additions & 13 deletions

@@ -6,12 +6,13 @@
 import torch
 
 from vllm import envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig, QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata, make_local_attention_virtual_batches,
-    subclass_attention_backend, subclass_attention_metadata_builder)
+    subclass_attention_backend)
 
 from ..layer import Attention
 
@@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
 ) -> type[AttentionBackend]:
     prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
 
-    def build_preprocess_fn(cm: CommonAttentionMetadata):
-        return make_local_attention_virtual_batches(attention_chunk_size, cm,
-                                                    block_size)
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            common_attn_metadata = make_local_attention_virtual_batches(
+                attention_chunk_size, common_attn_metadata, block_size)
+            return super().build(common_prefix_len, common_attn_metadata,
+                                 fast_build)
 
-    # Dynamically create a new attention backend that wraps the
-    # underlying attention backend but applies
-    # `make_local_attention_virtual_batches` before calling `build(...)`
-    builder_cls = subclass_attention_metadata_builder(
-        name_prefix=prefix,
-        builder_cls=underlying_attn_backend.get_builder_cls(),
-        build_preprocess_fn=build_preprocess_fn)
     attn_backend = subclass_attention_backend(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
-        builder_cls=builder_cls)
+        builder_cls=ChunkedLocalAttentionBuilder)
 
     return attn_backend
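
The change above replaces the removed `subclass_attention_metadata_builder` helper with a direct subclass of the underlying backend's builder: the subclass rewrites the shared metadata and then defers to `super().build(...)`. A minimal, self-contained sketch of that pattern follows; `Metadata`, `BaseBuilder`, and `make_wrapped_builder` are illustrative stand-ins, not vLLM classes.

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Metadata:
    num_tokens: int
    causal: bool = True


class BaseBuilder:

    def build(self, common_prefix_len: int, metadata: Metadata) -> Metadata:
        # A real builder would emit backend-specific attention metadata;
        # here we simply return the (possibly rewritten) common metadata.
        return metadata


def make_wrapped_builder(preprocess):
    """Return a BaseBuilder subclass that rewrites metadata before building."""

    class WrappedBuilder(BaseBuilder):

        def build(self, common_prefix_len: int,
                  metadata: Metadata) -> Metadata:
            return super().build(common_prefix_len, preprocess(metadata))

    return WrappedBuilder


# e.g. force non-causal metadata, analogous to the encoder-only wrapper below
NonCausalBuilder = make_wrapped_builder(lambda m: replace(m, causal=False))
print(NonCausalBuilder().build(0, Metadata(num_tokens=8)))  # causal=False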

vllm/attention/layers/encoder_only_attention.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from copy import copy
+from typing import Optional
+
+import torch
+from transformers import CacheConfig
+
+from vllm import envs
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.layer import Attention
+from vllm.attention.selector import get_attn_backend
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              subclass_attention_backend)
+
+
+@functools.lru_cache
+def create_encoder_only_attention_backend(
+    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
+    prefix = "EncoderOnlyAttention_"
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata = copy(common_attn_metadata)
+            new_common_attn_metadata.causal = False
+            return super().build(common_prefix_len, new_common_attn_metadata,
+                                 fast_build)
+
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=EncoderOnlyAttentionBuilder)
+
+    return attn_backend
+
+
+class EncoderOnlyAttention(Attention):
+    """
+    Encoder attention is a special case that doesn't need a KV Cache.
+    """
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 cache_config: Optional[CacheConfig] = None,
+                 attn_type: Optional[str] = None,
+                 **kwargs):
+        dtype = torch.get_default_dtype()
+
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if envs.VLLM_USE_V1:
+            underlying_attn_backend = get_attn_backend(head_size, dtype,
+                                                       kv_cache_dtype,
+                                                       block_size)
+
+            attn_backend = create_encoder_only_attention_backend(
+                underlying_attn_backend)
+        else:
+            # in v0 encoder only attention is handled inside the backends
+            attn_backend = None
+
+        if attn_type is not None:
+            assert attn_type == AttentionType.ENCODER_ONLY, \
+                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
+
+        super().__init__(num_heads=num_heads,
+                         head_size=head_size,
+                         scale=scale,
+                         cache_config=cache_config,
+                         attn_backend=attn_backend,
+                         attn_type=AttentionType.ENCODER_ONLY,
+                         **kwargs)
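
Note that `create_encoder_only_attention_backend` is wrapped in `functools.lru_cache`, so every `EncoderOnlyAttention` layer built on the same underlying backend reuses a single dynamically created subclass instead of minting a new one per layer. A small self-contained sketch of that caching behaviour, using placeholder names rather than vLLM APIs:

import functools


class FlashBackend:  # stand-in for an underlying attention backend class
    pass


@functools.lru_cache
def create_wrapped_backend(underlying: type) -> type:
    # Dynamically create (and cache) one wrapper subclass per backend class.
    return type(f"EncoderOnly{underlying.__name__}", (underlying, ), {})


a = create_wrapped_backend(FlashBackend)
b = create_wrapped_backend(FlashBackend)
assert a is b  # one cached subclass per underlying backend, not per layer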

vllm/model_executor/models/bert.py

Lines changed: 8 additions & 9 deletions

@@ -8,7 +8,7 @@
 from torch import nn
 from transformers import BertConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -239,14 +239,13 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj")
 
-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
 
     def forward(
         self,
vllm/model_executor/models/bert_with_rope.py

Lines changed: 8 additions & 9 deletions

@@ -7,7 +7,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -119,14 +119,13 @@ def __init__(
 
         self.rotary_emb = get_rope(**rotary_kwargs)
 
-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
 
         self.out_proj = RowParallelLinear(input_size=hidden_size,
                                           output_size=hidden_size,

vllm/model_executor/models/llama.py

Lines changed: 5 additions & 1 deletion

@@ -31,6 +31,7 @@
 from transformers import LlamaConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -173,7 +174,10 @@ def __init__(
         if is_sliding:
             sliding_window = config.sliding_window
 
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,

vllm/model_executor/models/modernbert.py

Lines changed: 7 additions & 7 deletions

@@ -7,7 +7,7 @@
 from torch import nn
 from transformers import ModernBertConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -104,12 +104,12 @@ def __init__(self,
                                       head_size=self.head_dim,
                                       dim=self.head_dim,
                                       base=rope_theta)
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY,
-                              per_layer_sliding_window=sliding_window)
+        self.attn = EncoderOnlyAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            prefix=f"{layer_id}.attn",
+            per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)

vllm/model_executor/models/qwen2.py

Lines changed: 4 additions & 1 deletion

@@ -32,6 +32,7 @@
 from transformers import Qwen2Config
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -159,7 +160,9 @@ def __init__(
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 31 deletions

@@ -5,8 +5,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
-                    TypeVar)
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
 
 import numpy as np
 import torch
@@ -543,35 +542,6 @@ def make_local_attention_virtual_batches(
     )
 
 
-def subclass_attention_metadata_builder(
-    name_prefix: str,
-    builder_cls: type[AttentionMetadataBuilder[M]],
-    build_preprocess_fn: Callable[[CommonAttentionMetadata],
-                                  CommonAttentionMetadata],
-) -> type[AttentionMetadataBuilder[M]]:
-    """
-    Return a new subclass of `builder_cls` whose .build(...) method
-    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
-    """
-    name: str = name_prefix + builder_cls.__name__  # type: ignore
-
-    def build(self,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False):
-        return builder_cls.build(self, common_prefix_len,
-                                 build_preprocess_fn(common_attn_metadata),
-                                 fast_build)
-
-    Wrapped = type(
-        name,
-        (builder_cls, ),  # inherit from the original
-        {
-            "build": build,
-        })
-    return Wrapped  # type: ignore
-
-
 def subclass_attention_backend(
     name_prefix: str, attention_backend_cls: type[AttentionBackend],
     builder_cls: type[AttentionMetadataBuilder[M]]
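
The surviving helper `subclass_attention_backend` is not shown in this diff. As a rough sketch built only from its signature above (an assumption, not the actual vLLM implementation), it can be pictured as creating a renamed backend subclass whose `get_builder_cls()` returns the supplied builder:

# Assumption: a simplified stand-in for subclass_attention_backend, inferred
# from its signature; the real vLLM implementation may differ.
def subclass_attention_backend_sketch(name_prefix: str,
                                      attention_backend_cls: type,
                                      builder_cls: type) -> type:
    name = name_prefix + attention_backend_cls.__name__
    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": staticmethod(lambda: builder_cls)})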

vllm/v1/kv_cache_interface.py

Lines changed: 8 additions & 0 deletions

@@ -203,6 +203,14 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         return self.page_size_bytes
 
 
+@dataclass(frozen=True)
+class EncoderOnlyAttentionSpec(AttentionSpec):
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # Encoder-only layers do not need KV cache
+        return 0
+
+
 @dataclass
 class KVCacheTensor:
     """
