@@ -8,11 +8,10 @@
 import pytest
 import torch

-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -110,54 +109,15 @@ def create_common_attn_metadata(
     )


-def get_attention_backend(backend_name: _Backend):
-    """Set up attention backend classes for testing.
-
-    Args:
-        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
-        vllm_config: VllmConfig instance
-
-    Returns:
-        Tuple of (backend_builder_class, backend_impl_class)
-    """
-    backend_map = {
-        _Backend.FLASH_ATTN:
-        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
-         if current_platform.is_cuda() else
-         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
-         ),
-        _Backend.FLASHINFER:
-        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
-        _Backend.FLEX_ATTENTION:
-        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
-        _Backend.TRITON_ATTN:
-        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
-        _Backend.TREE_ATTN:
-        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-        _Backend.XFORMERS:
-        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
-        _Backend.CUTLASS_MLA:
-        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
-        _Backend.FLASHMLA:
-        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
-        _Backend.FLASH_ATTN_MLA:
-        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
-        _Backend.FLASHINFER_MLA:
-        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
-        _Backend.TRITON_MLA:
-        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
-    }
-
-    if backend_name not in backend_map:
-        raise ValueError(f"Unknown backend: {backend_name}")
-
-    backend_class_name = backend_map[backend_name]
-
+def try_get_attention_backend(backend: _Backend) -> tuple[type, type]:
+    """Try to get the attention backend class, skipping test if not found."""
+    backend_class_str = backend_to_class_str(backend)
     try:
-        backend_class = resolve_obj_by_qualname(backend_class_name)
+        backend_class = resolve_obj_by_qualname(backend_class_str)
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
-        pytest.skip(f"{backend_name} not available: {e}")
+        pytest.skip(f"{backend_class_str} not available: {e}")
+    assert False  # unreachable -- satisfies mypy


 def create_standard_kv_cache_spec(
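For context, a minimal sketch of how a test could call the new helper. This is illustrative only, not part of the diff: the test name, the parametrized backend choices, and the import path "tests.v1.attention.utils" for the helper's module are assumptions.

# Illustrative usage sketch of try_get_attention_backend() (assumptions noted above).
import pytest

from vllm.attention.backends.registry import _Backend
from tests.v1.attention.utils import try_get_attention_backend  # assumed module path


@pytest.mark.parametrize("backend", [_Backend.FLASH_ATTN, _Backend.FLASHINFER])
def test_backend_classes_resolve(backend):
    # The helper calls pytest.skip() internally when the backend's optional
    # dependency is missing; otherwise both returned objects are classes.
    builder_cls, impl_cls = try_get_attention_backend(backend)
    assert isinstance(builder_cls, type)
    assert isinstance(impl_cls, type)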
|