@@ -31,7 +31,7 @@ def clear_cache():
3131}
3232
3333DEVICE_REGULAR_ATTN_BACKENDS = {
34- "cuda" : ["XFORMERS" , "FLASHINFER" ],
34+ "cuda" : ["XFORMERS" , "FLASHINFER" , "FLASH_ATTN" ],
3535 "hip" : ["ROCM_FLASH" ],
3636 "cpu" : ["TORCH_SDPA" ],
3737}
@@ -86,7 +86,7 @@ def test_env(
8686 with patch ("vllm.attention.selector.current_platform" ,
8787 CpuPlatform ()):
8888 backend = get_attn_backend (16 , torch .float16 , None , block_size )
89- assert backend .get_name () == "TORCH_SDPA_VLLM_V1 "
89+ assert backend .get_name () == "TORCH_SDPA "
9090
9191 elif device == "hip" :
9292 with patch ("vllm.attention.selector.current_platform" ,
@@ -125,15 +125,15 @@ def test_env(
125125 None ,
126126 block_size ,
127127 use_mla = use_mla )
128- expected = f" { name } _VLLM_V1"
128+ expected = name
129129 assert backend .get_name () == expected
130130 else :
131131 backend = get_attn_backend (16 ,
132132 torch .float16 ,
133133 None ,
134134 block_size ,
135135 use_mla = use_mla )
136- expected = "TRITON_ATTN_VLLM_V1 "
136+ expected = "TRITON_ATTN "
137137 assert backend .get_name () == expected
138138
139139 elif device == "cuda" :
@@ -160,7 +160,7 @@ def test_env(
160160 None ,
161161 block_size ,
162162 use_mla = use_mla )
163- expected = "CUTLASS_MLA_VLLM_V1 "
163+ expected = "CUTLASS_MLA "
164164 assert backend .get_name () == expected
165165 elif name == "FLASHINFER_MLA" :
166166 if block_size not in [32 , 64 ]:
@@ -193,7 +193,7 @@ def test_env(
193193 None ,
194194 block_size ,
195195 use_mla = use_mla )
196- expected = f" { name } _VLLM_V1"
196+ expected = name
197197 assert backend .get_name () == expected
198198 elif name == "FLASH_ATTN_MLA" :
199199 backend = get_attn_backend (16 ,
@@ -210,33 +210,32 @@ def test_env(
210210 None ,
211211 block_size ,
212212 use_mla = use_mla )
213- expected = "TRITON_MLA_VLLM_V1 "
213+ expected = "TRITON_MLA "
214214 assert backend .get_name () == expected
215215 elif name == "FLASHINFER" :
216216 backend = get_attn_backend (16 ,
217217 torch .float16 ,
218218 None ,
219219 block_size ,
220220 use_mla = use_mla )
221- expected = "FLASHINFER_VLLM_V1 "
221+ expected = "FLASHINFER "
222222 assert backend .get_name () == expected
223- else :
223+ elif name == "XFORMERS" :
224224 backend = get_attn_backend (32 ,
225225 torch .float16 ,
226226 None ,
227227 block_size ,
228228 use_mla = use_mla )
229- expected = "FLASH_ATTN_VLLM_V1 "
229+ expected = "XFORMERS "
230230 assert backend .get_name () == expected
231-
232- backend = get_attn_backend (16 ,
231+ elif name == "FLASH_ATTN" :
232+ backend = get_attn_backend (32 ,
233233 torch .float16 ,
234234 None ,
235235 block_size ,
236236 use_mla = use_mla )
237- assert backend .get_name () == "FLEX_ATTENTION" , (
238- "Should fallback to FlexAttention if head size is "
239- "not supported by FlashAttention" )
237+ expected = "FLASH_ATTN"
238+ assert backend .get_name () == expected
240239
241240
242241@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
@@ -252,7 +251,7 @@ def test_fp32_fallback(
252251 with patch ("vllm.attention.selector.current_platform" ,
253252 CpuPlatform ()):
254253 backend = get_attn_backend (16 , torch .float32 , None , 16 )
255- assert backend .get_name () == "TORCH_SDPA_VLLM_V1 "
254+ assert backend .get_name () == "TORCH_SDPA "
256255
257256 elif device == "cuda" :
258257 with patch ("vllm.attention.selector.current_platform" ,
@@ -266,6 +265,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
266265 # TODO: When testing for v1, pipe in `use_v1` as an argument to
267266 # get_attn_backend
268267
268+ pytest .skip ("Skipping as current backend selector does not " \
269+ "handle fallbacks when a backend is set via env var." )
270+
269271 with monkeypatch .context () as m :
270272 m .setenv (STR_BACKEND_ENV_VAR , STR_FLASH_ATTN_VAL )
271273
0 commit comments