
Commit 9de664c

MatthewBonanni authored and xuebwang-amd committed
[Attention] Implement universal BACKEND_MAP (vllm-project#25900)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent b58c30d commit 9de664c

File tree

12 files changed (+119, -75 lines)


tests/kernels/attention/test_attention_selector.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ def clear_cache():
 
 DEVICE_REGULAR_ATTN_BACKENDS = {
     "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
-    "hip": ["ROCM_FLASH"],
+    "hip": ["ROCM_ATTN"],
     "cpu": ["TORCH_SDPA"],
 }
 
@@ -122,7 +122,7 @@ def test_env(
             backend = get_attn_backend(
                 16, torch.float16, None, block_size, use_mla=use_mla
             )
-            expected = "TRITON_ATTN"
+            expected = "ROCM_ATTN"
             assert backend.get_name() == expected
 
         elif device == "cuda":

tests/kernels/attention/test_rocm_attention_selector.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def clear_cache():
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
+        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")
 
         # Set the current platform to ROCm using monkeypatch
         monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())

tests/v1/attention/test_attention_backends.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
@@ -214,7 +214,7 @@ def run_attention_backend(
         actual_backend = _Backend.FLEX_ATTENTION
         use_direct_block_mask = False
 
-    builder_cls, impl_cls = get_attention_backend(actual_backend)
+    builder_cls, impl_cls = try_get_attention_backend(actual_backend)
 
     # Mock flashinfer's get_per_layer_parameters if needed
     if actual_backend == _Backend.FLASHINFER:

tests/v1/attention/test_mla_backends.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
 from vllm.attention.backends.registry import _Backend
@@ -239,7 +239,7 @@ def run_attention_backend(
 ) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)
 
     # Build metadata
     builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
@@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     # Determine if this is decode or prefill
     is_decode = []
     for i, backend in enumerate(BACKENDS_TO_TEST):
-        builder_cls, _ = get_attention_backend(backend)
+        builder_cls, _ = try_get_attention_backend(backend)
         is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
 
     # Split q into nope and rope components

tests/v1/attention/utils.py

Lines changed: 14 additions & 38 deletions
@@ -8,7 +8,8 @@
 import pytest
 import torch
 
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.abstract import AttentionImpl
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -20,9 +21,11 @@
     VllmConfig,
 )
 from vllm.config.model import ModelDType
-from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 
@@ -117,44 +120,17 @@ def create_common_attn_metadata(
     )
 
 
-def get_attention_backend(backend_name: _Backend):
-    """Set up attention backend classes for testing.
-
-    Args:
-        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
-        vllm_config: VllmConfig instance
-
-    Returns:
-        Tuple of (backend_builder_class, backend_impl_class)
-    """
-    backend_map = {
-        _Backend.FLASH_ATTN: (
-            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
-            if current_platform.is_cuda()
-            else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
-        ),
-        _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
-        _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
-        _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
-        _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-        _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
-        _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
-        _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
-        _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
-        _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
-        _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
-    }
-
-    if backend_name not in backend_map:
-        raise ValueError(f"Unknown backend: {backend_name}")
-
-    backend_class_name = backend_map[backend_name]
-
+def try_get_attention_backend(
+    backend: _Backend,
+) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
+    """Try to get the attention backend class, skipping test if not found."""
+    backend_class_str = backend_to_class_str(backend)
     try:
-        backend_class = resolve_obj_by_qualname(backend_class_name)
+        backend_class = resolve_obj_by_qualname(backend_class_str)
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
-        pytest.skip(f"{backend_name} not available: {e}")
+        pytest.skip(f"{backend_class_str} not available: {e}")
+        raise AssertionError("unreachable") from None
 
 
 def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
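
A minimal usage sketch of the new helper (the test function name and the FLASH_ATTN choice below are illustrative only, not part of this diff):

from tests.v1.attention.utils import try_get_attention_backend
from vllm.attention.backends.registry import _Backend


def test_flash_attn_backend_resolves():
    # Resolves FLASH_ATTN through the shared BACKEND_MAP; if the backend's
    # dependencies are missing, the helper calls pytest.skip() itself.
    builder_cls, impl_cls = try_get_attention_backend(_Backend.FLASH_ATTN)

    # The helper returns classes, not instances; builders are constructed
    # later with kv_cache_spec, layer_names, vllm_config, and device.
    assert isinstance(builder_cls, type)
    assert isinstance(impl_cls, type)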

tests/v1/spec_decode/test_eagle.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
     BatchSpec,
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (
@@ -535,11 +535,11 @@ def create_deterministic_logits(token_ids):
     sampling_metadata = mock.MagicMock()
 
     if attn_backend == "FLASH_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
     elif attn_backend == "TRITON_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
     elif attn_backend == "TREE_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
 
@@ -674,7 +674,7 @@ def create_deterministic_logits(token_ids, k: int):
     proposer.attn_layer_names = ["layer.0"]
 
     # Get the tree attention metadata builder.
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,

tests/v1/spec_decode/test_mtp.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
     BatchSpec,
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (
@@ -177,7 +177,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset):
     sampling_metadata = mock.MagicMock()
 
     # Setup attention metadata
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
 
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),

tests/v1/spec_decode/test_tree_attention.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
@@ -63,7 +63,7 @@ def forward_attention(
 
     # Build common metadata.
     model_name = "meta-llama/Meta-Llama-3-8B"
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)
     vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
     if spec_token_tree is not None:
         # Create speculative config if token tree is specified.

vllm/attention/backends/registry.py

Lines changed: 83 additions & 2 deletions
@@ -3,13 +3,16 @@
 """Attention backend registry"""
 
 import enum
+from typing import Optional
+
+from vllm.utils import resolve_obj_by_qualname
 
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     TRITON_ATTN = enum.auto()
     XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
+    ROCM_ATTN = enum.auto()
     ROCM_AITER_MLA = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
     ROCM_AITER_UNIFIED_ATTN = enum.auto()
+
+
+BACKEND_MAP = {
+    _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",  # noqa: E501
+    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
+    _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
+    _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend",  # noqa: E501
+    _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend",  # noqa: E501
+    _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend",  # noqa: E501
+    _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",  # noqa: E501
+    _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
+    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
+    _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
+    _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
+    _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
+    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
+    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
+    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
+}
+
+
+def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
+    """
+    Decorator: register a custom attention backend into BACKEND_MAP.
+    - If class_path is provided, use it.
+    - Otherwise, auto-generate it from the class object.
+    Validation: only checks whether 'backend' is a valid _Backend enum member.
+    Overwriting existing mappings is allowed. This enables other hardware
+    platforms to plug in custom out-of-tree backends.
+    """
+    if not isinstance(backend, _Backend):
+        raise ValueError(f"{backend} is not a valid _Backend enum value.")
+
+    def decorator(cls):
+        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
+        BACKEND_MAP[backend] = path
+        return cls
+
+    return decorator
+
+
+def backend_to_class_str(backend: _Backend) -> str:
+    """Get the backend class string.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class string
+    """
+    return BACKEND_MAP[backend]
+
+
+def backend_to_class(backend: _Backend) -> type:
+    """Get the backend class.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class
+    """
+    backend_class_name = backend_to_class_str(backend)
+    return resolve_obj_by_qualname(backend_class_name)
+
+
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+        _Backend: enum value if backend_name is a valid in-tree type
+        None: otherwise it's an invalid in-tree type or an out-of-tree platform
+              is loaded.
+    """
+    assert backend_name is not None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else None
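
A hedged sketch of how the new registry API might be used: it registers a hypothetical out-of-tree backend (MyPlatformAttentionBackend is invented for this example) and then resolves names back to classes. Because overwriting is allowed, the decorator re-points _Backend.FLASH_ATTN at the custom class here; a real platform plugin would pick the entry it actually wants to replace.

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import (
    _Backend,
    backend_name_to_enum,
    backend_to_class,
    register_attn_backend,
)


@register_attn_backend(_Backend.FLASH_ATTN)
class MyPlatformAttentionBackend(AttentionBackend):
    """Hypothetical backend; a real one implements the full
    AttentionBackend interface (get_impl_cls, get_builder_cls, ...)."""


# The decorator stored "<module>.MyPlatformAttentionBackend" in BACKEND_MAP,
# so enum-to-class resolution now returns the custom class.
assert backend_to_class(_Backend.FLASH_ATTN) is MyPlatformAttentionBackend

# String names map to enum members; unknown names return None rather than raising.
assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
assert backend_name_to_enum("NOT_A_BACKEND") is None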

vllm/attention/layer.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import _Backend
-from vllm.attention.selector import backend_name_to_enum, get_attn_backend
+from vllm.attention.backends.registry import _Backend, backend_name_to_enum
+from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (
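
For context on the import move above, backend_name_to_enum now lives in the registry module next to BACKEND_MAP, so string-to-enum-to-class lookup happens in one place. A small sketch of that chain, assuming the usual VLLM_ATTENTION_BACKEND environment variable (the TRITON_ATTN default below is only an example value):

import os

from vllm.attention.backends.registry import (
    backend_name_to_enum,
    backend_to_class_str,
)

# VLLM_ATTENTION_BACKEND holds a backend name such as "TRITON_ATTN";
# the fallback here is just for illustration.
name = os.environ.get("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
backend = backend_name_to_enum(name)

if backend is not None:
    # One universal mapping replaces the scattered per-module tables.
    print(backend_to_class_str(backend))
else:
    print(f"{name!r} is not an in-tree backend name")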
