33"""Attention backend registry"""
44
55import enum
6+ from typing import Optional
7+
8+ from vllm .utils import resolve_obj_by_qualname
69
710
811class _Backend (enum .Enum ):
912 FLASH_ATTN = enum .auto ()
1013 TRITON_ATTN = enum .auto ()
1114 XFORMERS = enum .auto ()
12- ROCM_FLASH = enum .auto ()
15+ ROCM_ATTN = enum .auto ()
1316 ROCM_AITER_MLA = enum .auto ()
1417 ROCM_AITER_FA = enum .auto () # used for ViT attn backend
1518 TORCH_SDPA = enum .auto ()
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
2427 NO_ATTENTION = enum .auto ()
2528 FLEX_ATTENTION = enum .auto ()
2629 TREE_ATTN = enum .auto ()
27- ROCM_ATTN = enum .auto ()
2830 ROCM_AITER_UNIFIED_ATTN = enum .auto ()
31+
32+
# Registry of attention backends: each _Backend member maps to the fully
# qualified dotted path of its implementation class. Paths are kept as
# strings so a backend module is only imported when it is resolved.
# Entries may be replaced at runtime through register_attn_backend().
# Not every enum member has an entry (e.g. NO_ATTENTION is absent).
BACKEND_MAP = {
    _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",  # noqa: E501
    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
    _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
    _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend",  # noqa: E501
    _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend",  # noqa: E501
    _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend",  # noqa: E501
    _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend",  # noqa: E501
    _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",  # noqa: E501
    _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
    _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
    _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
    _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
}
52+
53+
def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
    """Return a class decorator that registers a backend in BACKEND_MAP.

    The decorated class is returned unchanged; the only effect is that
    ``BACKEND_MAP[backend]`` is set to a dotted class path. Existing entries
    are overwritten without warning, which lets other hardware platforms
    plug in custom out-of-tree backends.

    Args:
        backend: The enum member to register the class under.
        class_path: Dotted path to store. When omitted (or empty), the path
            is derived from the decorated class as
            ``<cls.__module__>.<cls.__qualname__>``.

    Returns:
        A decorator that records the mapping and returns the class as-is.

    Raises:
        ValueError: If ``backend`` is not a ``_Backend`` member. This is the
            only validation performed; ``class_path`` is not checked.
    """
    if not isinstance(backend, _Backend):
        raise ValueError(f"{backend} is not a valid _Backend enum value.")

    def decorator(cls):
        # The stored value must be a plain dotted path with no surrounding
        # or embedded whitespace, since backend_to_class() hands it to
        # resolve_obj_by_qualname() for import.
        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
        BACKEND_MAP[backend] = path
        return cls

    return decorator
72+
73+
def backend_to_class_str(backend: _Backend) -> str:
    """Look up the dotted class path registered for ``backend``.

    Args:
        backend: The backend enum value.

    Returns:
        The class path string currently stored in ``BACKEND_MAP``.

    Raises:
        KeyError: If ``backend`` has no entry in ``BACKEND_MAP``.
    """
    class_path = BACKEND_MAP[backend]
    return class_path
84+
85+
def backend_to_class(backend: _Backend) -> type:
    """Resolve ``backend`` to its implementation class.

    The registered dotted path is fetched with ``backend_to_class_str`` and
    then passed to ``resolve_obj_by_qualname``.

    Args:
        backend: The backend enum value.

    Returns:
        The object that ``resolve_obj_by_qualname`` returns for the
        registered class path.
    """
    return resolve_obj_by_qualname(backend_to_class_str(backend))
97+
98+
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """Map a backend name string onto the matching ``_Backend`` member.

    The match is an exact, case-sensitive comparison against the enum
    member names.

    Returns:
        _Backend: the member whose name equals ``backend_name``.
        None: when no member has that name, i.e. it is an invalid in-tree
            type or an out-of-tree platform is loaded.
    """
    assert backend_name is not None
    # __members__ is a name -> member mapping, so one .get() covers both the
    # membership test and the lookup.
    return _Backend.__members__.get(backend_name)
0 commit comments