
Commit 3ba013e

attention backends register themselves in the mapping
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
1 parent 5f42fc5 commit 3ba013e

28 files changed: 129 additions & 86 deletions

tests/kernels/attention/test_attention_selector.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def clear_cache():
 
 DEVICE_REGULAR_ATTN_BACKENDS = {
     "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
-    "hip": ["ROCM_FLASH"],
+    "hip": ["ROCM_AITER_FA"],
     "cpu": ["TORCH_SDPA"],
 }
 
tests/kernels/attention/test_rocm_attention_selector.py

Lines changed: 3 additions & 3 deletions
@@ -19,16 +19,16 @@ def clear_cache():
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
+        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_FA")
 
         # Set the current platform to ROCm using monkeypatch
         monkeypatch.setattr("vllm.attention.selector.current_platform",
                             RocmPlatform())
 
         # Test standard ROCm attention
         backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
-        assert (backend.get_name() == "ROCM_FLASH"
-                or backend.get_name() == "TRITON_ATTN")
+        assert (backend.get_name() == "ROCM_AITER_FA"
+                or backend.get_name() == "ROCM_ATTN")
 
         # MLA test for deepseek related
 
tests/v1/attention/test_attention_backends.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
-                                      get_attention_backend)
+                                      try_get_attention_backend)
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
 from vllm.platforms import current_platform
@@ -210,7 +210,7 @@ def run_attention_backend(
         actual_backend = _Backend.FLEX_ATTENTION
         use_direct_block_mask = False
 
-    builder_cls, impl_cls = get_attention_backend(actual_backend)
+    builder_cls, impl_cls = try_get_attention_backend(actual_backend)
 
     # Mock flashinfer's get_per_layer_parameters if needed
     if actual_backend == _Backend.FLASHINFER:

tests/v1/attention/test_mla_backends.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
-                                      get_attention_backend)
+                                      try_get_attention_backend)
 from vllm import _custom_ops as ops
 from vllm.attention.backends.registry import _Backend
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -232,7 +232,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
                           mock_kv_b_proj) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)
 
     # Build metadata
     builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
@@ -393,7 +393,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     # Determine if this is decode or prefill
     is_decode = []
     for i, backend in enumerate(BACKENDS_TO_TEST):
-        builder_cls, _ = get_attention_backend(backend)
+        builder_cls, _ = try_get_attention_backend(backend)
         is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
 
     # Split q into nope and rope components

tests/v1/attention/utils.py

Lines changed: 7 additions & 47 deletions
@@ -8,11 +8,10 @@
 import pytest
 import torch
 
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -110,54 +109,15 @@ def create_common_attn_metadata(
     )
 
 
-def get_attention_backend(backend_name: _Backend):
-    """Set up attention backend classes for testing.
-
-    Args:
-        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
-        vllm_config: VllmConfig instance
-
-    Returns:
-        Tuple of (backend_builder_class, backend_impl_class)
-    """
-    backend_map = {
-        _Backend.FLASH_ATTN:
-        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
-         if current_platform.is_cuda() else
-         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
-         ),
-        _Backend.FLASHINFER:
-        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
-        _Backend.FLEX_ATTENTION:
-        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
-        _Backend.TRITON_ATTN:
-        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
-        _Backend.TREE_ATTN:
-        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-        _Backend.XFORMERS:
-        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
-        _Backend.CUTLASS_MLA:
-        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
-        _Backend.FLASHMLA:
-        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
-        _Backend.FLASH_ATTN_MLA:
-        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
-        _Backend.FLASHINFER_MLA:
-        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
-        _Backend.TRITON_MLA:
-        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
-    }
-
-    if backend_name not in backend_map:
-        raise ValueError(f"Unknown backend: {backend_name}")
-
-    backend_class_name = backend_map[backend_name]
-
+def try_get_attention_backend(backend: _Backend) -> tuple[type, type]:
+    """Try to get the attention backend class, skipping test if not found."""
+    backend_class_str = backend_to_class_str(backend)
     try:
-        backend_class = resolve_obj_by_qualname(backend_class_name)
+        backend_class = resolve_obj_by_qualname(backend_class_str)
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
-        pytest.skip(f"{backend_name} not available: {e}")
+        pytest.skip(f"{backend_class_str} not available: {e}")
+    assert False  # unreachable -- satisfies mypy
 
 
 def create_standard_kv_cache_spec(
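
The new helper resolves the class path through the shared registry instead of a test-local map, and converts a missing optional dependency into a pytest skip. A minimal usage sketch follows; the test name is hypothetical, everything else is the API added in this commit.

# Illustrative only: resolve FLASHINFER through the registry-backed helper.
# If flashinfer is not installed, try_get_attention_backend() calls
# pytest.skip() instead of raising ImportError.
from tests.v1.attention.utils import try_get_attention_backend
from vllm.attention.backends.registry import _Backend


def test_flashinfer_classes_resolve():  # hypothetical test name
    builder_cls, impl_cls = try_get_attention_backend(_Backend.FLASHINFER)
    # The returned objects are the backend's metadata builder and impl classes.
    assert isinstance(builder_cls, type) and isinstance(impl_cls, type)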

tests/v1/spec_decode/test_eagle.py

Lines changed: 6 additions & 5 deletions
@@ -10,7 +10,7 @@
 from tests.utils import get_attn_backend_list_based_on_platform
 from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
-                                      get_attention_backend)
+                                      try_get_attention_backend)
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -515,13 +515,13 @@ def create_deterministic_logits(token_ids):
     sampling_metadata = mock.MagicMock()
 
     if attn_backend == "FLASH_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
            _Backend.FLASH_ATTN)
     elif attn_backend == "TRITON_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
            _Backend.TRITON_ATTN)
     elif attn_backend == "TREE_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
            _Backend.TREE_ATTN)
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -653,7 +653,8 @@ def create_deterministic_logits(token_ids, k: int):
     proposer.attn_layer_names = ["layer.0"]
 
     # Get the tree attention metadata builder.
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        _Backend.TREE_ATTN)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,

tests/v1/spec_decode/test_mtp.py

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
-                                      get_attention_backend)
+                                      try_get_attention_backend)
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -174,7 +174,8 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset):
     sampling_metadata = mock.MagicMock()
 
     # Setup attention metadata
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        _Backend.FLASH_ATTN)
 
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),

tests/v1/spec_decode/test_tree_attention.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 from tests.v1.attention.utils import (create_standard_kv_cache_spec,
                                       create_vllm_config,
-                                      get_attention_backend)
+                                      try_get_attention_backend)
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -60,7 +60,7 @@ def forward_attention(
 
     # Build common metadata.
     model_name = "meta-llama/Meta-Llama-3-8B"
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)
     vllm_config = create_vllm_config(model_name=model_name,
                                      max_model_len=max(seq_lens))
     if spec_token_tree is not None:

vllm/attention/backends/registry.py

Lines changed: 65 additions & 2 deletions
@@ -3,13 +3,16 @@
 """Attention backend registry"""
 
 import enum
+from typing import Optional, Type
+
+from vllm.utils import resolve_obj_by_qualname
 
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     TRITON_ATTN = enum.auto()
     XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
+    ROCM_ATTN = enum.auto()
     ROCM_AITER_MLA = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
@@ -24,4 +27,64 @@ class _Backend(enum.Enum):
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
+
+
+BACKEND_MAPPING = {}
+
+
+def register_attn_backend(backend: _Backend, class_path: str | None = None):
+    """
+    Decorator: register a custom attention backend into BACKEND_MAPPING.
+    - If class_path is provided, use it.
+    - Otherwise, auto-generate from the class object.
+    Validation: only checks if 'backend' is a valid _Backend enum member.
+    Overwriting existing mappings is allowed.
+    """
+    if not isinstance(backend, _Backend):
+        raise ValueError(f"{backend} is not a valid _Backend enum value.")
+
+    def decorator(cls):
+        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
+        BACKEND_MAPPING[backend] = path
+        return cls
+
+    return decorator
+
+
+def backend_to_class_str(backend: _Backend) -> str:
+    """Get the backend class string
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class string
+    """
+    return BACKEND_MAPPING[backend]
+
+
+def backend_to_class(backend: _Backend) -> Type:
+    """Get the backend class.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class
+    """
+    backend_class_name = backend_to_class_str(backend)
+    return resolve_obj_by_qualname(backend_class_name)
+
+
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+        _Backend: enum value if backend_name is a valid in-tree type
+        None: otherwise it's an invalid in-tree type or an out-of-tree platform
+              is loaded.
+    """
+    assert backend_name is not None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None
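
This is the mechanism named in the commit title: each backend registers itself, and callers resolve classes through BACKEND_MAPPING instead of hard-coded paths. Below is a minimal sketch of the intended use, assuming the decorator form; the actual registration edits live in the remaining files of this 28-file commit, which are not shown in this excerpt, and the class body here is only a stand-in.

# Sketch of self-registration: the stand-in class plays the role of the real
# FlashAttentionBackend in vllm/v1/attention/backends/flash_attn.py.
from vllm.attention.backends.registry import (BACKEND_MAPPING, _Backend,
                                              backend_to_class_str,
                                              register_attn_backend)


@register_attn_backend(_Backend.FLASH_ATTN)
class FlashAttentionBackend:  # stand-in class for illustration
    ...


# The decorator recorded this class's qualified name under FLASH_ATTN;
# an explicit dotted path can be supplied instead via class_path=... .
assert _Backend.FLASH_ATTN in BACKEND_MAPPING
assert backend_to_class_str(_Backend.FLASH_ATTN).endswith("FlashAttentionBackend")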

vllm/attention/layer.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 import vllm.envs as envs
 from vllm.attention import AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import _Backend
-from vllm.attention.selector import backend_name_to_enum, get_attn_backend
+from vllm.attention.backends.registry import _Backend, backend_name_to_enum
+from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
