
Commit 1dc8a70

[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent f825c6b commit 1dc8a70

13 files changed: +368 -212 lines changed
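
For orientation: this diff replaces the model runner's parallel lists (attn_metadata_builders, attn_backends) with a nested attn_groups structure, indexed first by KV-cache group and then by attention group within it. The sketch below is inferred only from the call sites in this diff (runner.attn_groups[0][0].metadata_builder, attn_group.backend, model_runner._attn_group_iterator()); the actual AttentionGroup definition lives in gpu_model_runner.py, which is not part of this excerpt, so treat the class and field layout as an assumption.

# Hypothetical sketch, not the real vLLM class; fields inferred from usage.
from dataclasses import dataclass


@dataclass
class AttentionGroupSketch:
    backend: type             # the AttentionBackend subclass shared by this group
    metadata_builder: object  # the AttentionMetadataBuilder instance for it


# runner.attn_groups: list[list[AttentionGroupSketch]]
#   outer index -> KV-cache group (one per kv_cache_spec)
#   inner index -> attention group within that KV-cache group
# e.g. runner.attn_groups[0][0].metadata_builder.build(...)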

tests/v1/spec_decode/test_eagle.py
Lines changed: 2 additions & 1 deletion

@@ -313,7 +313,8 @@ def create_deterministic_logits(token_ids):
 
     # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
-    proposer.runner.attn_metadata_builders = [attn_metadata_builder]
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,

tests/v1/worker/test_gpu_model_runner.py
Lines changed: 3 additions & 3 deletions

@@ -417,12 +417,12 @@ def rnd_stride_order():
         return rnd_stride
 
     # Patch the attention backend class and re-trigger the KV cache creation.
-    for attn_backend in model_runner.attn_backends:
+    for attn_group in model_runner._attn_group_iterator():
+        attn_backend = attn_group.backend
         monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
                             rnd_stride_order)
 
-    model_runner.attn_backends = []
-    model_runner.attn_metadata_builders = []
+    model_runner.attn_groups = []
     model_runner.initialize_kv_cache(model_runner.kv_cache_config)
 
     # Shape is unchanged, but layout may differ

vllm/attention/backends/abstract.py
Lines changed: 4 additions & 0 deletions

@@ -106,6 +106,10 @@ def advance_step(self, model_input: "ModelRunnerInputBase",
                      block_size: int, num_seqs: int, num_queries: int) -> None:
         raise NotImplementedError
 
+    @classmethod
+    def full_cls_name(cls) -> tuple[str, str]:
+        return (cls.__module__, cls.__qualname__)
+
 
 @dataclass
 class AttentionMetadata:
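
A minimal standalone sketch (toy class, not vLLM code) of what the new classmethod returns: a hashable (module, qualname) pair that stays distinct even for dynamically created backend subclasses, which presumably is what makes it usable as a grouping or cache key.

# Toy illustration mirroring the classmethod added above.
class _ToyBackend:

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)


_Derived = type("ChunkedLocal_ToyBackend", (_ToyBackend, ), {})
print(_ToyBackend.full_cls_name())  # ('__main__', '_ToyBackend')
print(_Derived.full_cls_name())     # ('__main__', 'ChunkedLocal_ToyBackend')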

vllm/attention/layer.py
Lines changed: 18 additions & 18 deletions

@@ -9,6 +9,7 @@
 
 import vllm.envs as envs
 from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -80,6 +81,7 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        attn_backend: Optional[type[AttentionBackend]] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -137,15 +139,6 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
 
-        # For v1 we have backend agnostic iRoPE (local chunked attention)
-        # we have to store the flag on the layer so gpu model runner can
-        # set KVSpec appropriately (and pop it so it doesnt get passed to
-        # the backends)
-        if envs.VLLM_USE_V1:
-            self.use_irope = extra_impl_args.pop("use_irope", False)
-        else:
-            self.use_irope = extra_impl_args.get("use_irope", False)
-
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@ def __init__(
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype,
-                                        block_size,
-                                        is_attention_free,
-                                        use_mla=use_mla)
-        impl_cls = attn_backend.get_impl_cls()
+        if attn_backend is None:
+            self.attn_backend = get_attn_backend(head_size,
+                                                 dtype,
+                                                 kv_cache_dtype,
+                                                 block_size,
+                                                 is_attention_free,
+                                                 use_mla=use_mla)
+        else:
+            self.attn_backend = attn_backend
+
+        impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              logits_soft_cap, attn_type,
                              kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.backend = backend_name_to_enum(self.attn_backend.get_name())
         self.dtype = dtype
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@ def __init__(
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        self.use_output = attn_backend.accept_output_buffer
+        self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+    def get_attn_backend(self) -> type[AttentionBackend]:
+        return self.attn_backend
+
 
 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
vllm/attention/layers/chunked_local_attention.py
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from typing import List, Optional
+
+import torch
+
+from vllm import envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.selector import get_attn_backend
+from vllm.config import CacheConfig, QuantizationConfig
+from vllm.v1.attention.backends.utils import (
+    CommonAttentionMetadata, make_local_attention_virtual_batches,
+    subclass_attention_backend, subclass_attention_metadata_builder)
+
+from ..layer import Attention
+
+
+@functools.lru_cache
+def create_chunked_local_attention_backend(
+    underlying_attn_backend: AttentionBackend,
+    attention_chunk_size: int,
+    block_size: int,
+) -> type[AttentionBackend]:
+    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
+
+    def build_preprocess_fn(cm: CommonAttentionMetadata):
+        return make_local_attention_virtual_batches(attention_chunk_size, cm,
+                                                    block_size)
+
+    # Dynamically create a new attention backend that wraps the
+    # underlying attention backend but applies
+    # `make_local_attention_virtual_batches` before calling `build(...)`
+    builder_cls = subclass_attention_metadata_builder(
+        name_prefix=prefix,
+        builder_cls=underlying_attn_backend.get_builder_cls(),
+        build_preprocess_fn=build_preprocess_fn)
+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=builder_cls)
+
+    return attn_backend
+
+
+class ChunkedLocalAttention(Attention):
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 attention_chunk_size: int,
+                 num_kv_heads: Optional[int] = None,
+                 alibi_slopes: Optional[List[float]] = None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 kv_sharing_target_layer_name: Optional[str] = None,
+                 prefix: str = ""):
+        dtype = torch.get_default_dtype()
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if envs.VLLM_USE_V1:
+            underlying_attn_backend = get_attn_backend(head_size, dtype,
+                                                       kv_cache_dtype,
+                                                       block_size)
+
+            attn_backend = create_chunked_local_attention_backend(
+                underlying_attn_backend, attention_chunk_size, block_size)
+        else:
+            # in v0 the local attention is handled inside the backends
+            attn_backend = None
+
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=alibi_slopes,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=prefix,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            attn_backend=attn_backend)
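
Because create_chunked_local_attention_backend is wrapped in functools.lru_cache, every layer constructed with the same (underlying backend, attention_chunk_size, block_size) receives the identical dynamically created subclass, so such layers can share one attention group rather than getting one per layer. A rough usage sketch under that assumption; it needs a platform where get_attn_backend can resolve a concrete backend, and the head size, dtype, and chunk size below are example values only.

import torch

from vllm.attention.layers.chunked_local_attention import (
    create_chunked_local_attention_backend)
from vllm.attention.selector import get_attn_backend

base = get_attn_backend(128, torch.bfloat16, "auto", 16)
wrapped_a = create_chunked_local_attention_backend(base, 8192, 16)
wrapped_b = create_chunked_local_attention_backend(base, 8192, 16)
assert wrapped_a is wrapped_b  # lru_cache returns the same subclass object
print(wrapped_a.__name__)  # e.g. "ChunkedLocalAttention_8192_16_<BaseBackend>"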

vllm/attention/selector.py
Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ def get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool,
+    is_attention_free: bool = False,
     use_mla: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""

vllm/model_executor/models/llama4.py
Lines changed: 6 additions & 4 deletions

@@ -25,6 +25,7 @@
 from transformers import Llama4TextConfig
 
 from vllm.attention import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -194,17 +195,18 @@ def __init__(self,
             is_neox_style=is_neox_style,
         ) if not self.nope else None
 
-        self.attn = Attention(
+        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
-            per_layer_sliding_window=None,
-            use_irope=not self.nope,
             prefix=f"{prefix}.attn",
-        )
+            **({
+                "attention_chunk_size": config.attention_chunk_size
+            } if not self.nope else {}))
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)

vllm/v1/attention/backends/utils.py
Lines changed: 46 additions & 2 deletions

@@ -5,12 +5,12 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
+                    TypeVar)
 
 import numpy as np
 import torch
 
-from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
 
@@ -20,6 +20,8 @@
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
 from vllm.logger import init_logger
@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
     )
 
 
+def subclass_attention_metadata_builder(
+    name_prefix: str,
+    builder_cls: type[AttentionMetadataBuilder[M]],
+    build_preprocess_fn: Callable[[CommonAttentionMetadata],
+                                  CommonAttentionMetadata],
+) -> type[AttentionMetadataBuilder[M]]:
+    """
+    Return a new subclass of `builder_cls` whose .build(...) method
+    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
+    """
+    name: str = name_prefix + builder_cls.__name__  # type: ignore
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False):
+        return builder_cls.build(self, common_prefix_len,
+                                 build_preprocess_fn(common_attn_metadata),
+                                 fast_build)
+
+    Wrapped = type(
+        name,
+        (builder_cls, ),  # inherit from the original
+        {
+            "build": build,
+        })
+    return Wrapped  # type: ignore
+
+
+def subclass_attention_backend(
+    name_prefix: str, attention_backend_cls: type[AttentionBackend],
+    builder_cls: type[AttentionMetadataBuilder[M]]
+) -> type[AttentionBackend]:
+    """
+    Return a new subclass where `get_builder_cls` returns `builder_cls`.
+    """
+    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
+
+    return type(name, (attention_backend_cls, ),
+                {"get_builder_cls": lambda: builder_cls})
+
+
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
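
The two helpers above rely on nothing vLLM-specific beyond type(): they mint a subclass whose build() preprocesses the metadata before delegating to the parent implementation. A standalone toy sketch of the same mechanism (toy names, not vLLM classes):

# Toy demonstration of the subclass-with-preprocess pattern used above.
class ToyBuilder:

    def build(self, common_prefix_len: int, metadata: dict,
              fast_build: bool = False) -> dict:
        return {"prefix_len": common_prefix_len, **metadata}


def subclass_with_preprocess(name_prefix, builder_cls, preprocess_fn):

    def build(self, common_prefix_len, metadata, fast_build=False):
        return builder_cls.build(self, common_prefix_len,
                                 preprocess_fn(metadata), fast_build)

    return type(name_prefix + builder_cls.__name__, (builder_cls, ),
                {"build": build})


LocalToyBuilder = subclass_with_preprocess(
    "ChunkedLocal_", ToyBuilder, lambda md: {**md, "virtual_batches": True})
print(LocalToyBuilder().build(4, {"seq_lens": [8, 8]}))
# -> {'prefix_len': 4, 'seq_lens': [8, 8], 'virtual_batches': True}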

vllm/v1/spec_decode/eagle.py
Lines changed: 5 additions & 4 deletions

@@ -158,9 +158,9 @@ def propose(
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builders[
-            0].build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                  draft_index=0)
+        attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
+            .build_for_drafting(common_attn_metadata=common_attn_metadata,
+                                draft_index=0)
 
         # At this moment, we assume all eagle layers belong to the same KV
         # cache group, thus using the same attention metadata.
@@ -349,7 +349,8 @@ def propose_tree(
         hidden_states: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
     ) -> list[torch.Tensor]:
-        tree_attn_metadata_builder = self.runner.attn_metadata_builders[0]
+        tree_attn_metadata_builder = \
+            self.runner.attn_groups[0][0].metadata_builder
         assert isinstance(tree_attn_metadata_builder,
                           TreeAttentionMetadataBuilder)
vllm/v1/worker/cpu_model_runner.py
Lines changed: 4 additions & 4 deletions

@@ -53,11 +53,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")
 
-        assert type(
-            self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
+        assert type(self.attn_groups[0]
+                    [0].metadata_builder) is TorchSDPAMetadataBuilderV1
 
-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
-                                                     scheduler_output)
+        self.attn_groups[0][0].metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
 
     def _postprocess_tenosrs(self) -> None:
         # Note: replace device tensors with cpu tensors
