@@ -31,7 +31,7 @@ def clear_cache():
3131}
3232
3333DEVICE_REGULAR_ATTN_BACKENDS  =  {
34-     "cuda" : ["XFORMERS" , "FLASHINFER" ],
34+     "cuda" : ["XFORMERS" , "FLASHINFER" ,  "FLASH_ATTN" ],
3535    "hip" : ["ROCM_FLASH" ],
3636    "cpu" : ["TORCH_SDPA" ],
3737}
@@ -86,7 +86,7 @@ def test_env(
8686            with  patch ("vllm.attention.selector.current_platform" ,
8787                       CpuPlatform ()):
8888                backend  =  get_attn_backend (16 , torch .float16 , None , block_size )
89-             assert  backend .get_name () ==  "TORCH_SDPA_VLLM_V1 " 
89+             assert  backend .get_name () ==  "TORCH_SDPA " 
9090
9191        elif  device  ==  "hip" :
9292            with  patch ("vllm.attention.selector.current_platform" ,
@@ -125,15 +125,15 @@ def test_env(
125125                                                   None ,
126126                                                   block_size ,
127127                                                   use_mla = use_mla )
128-                         expected  =  f" { name } _VLLM_V1" 
128+                         expected  =  name 
129129                        assert  backend .get_name () ==  expected 
130130                else :
131131                    backend  =  get_attn_backend (16 ,
132132                                               torch .float16 ,
133133                                               None ,
134134                                               block_size ,
135135                                               use_mla = use_mla )
136-                     expected  =  "TRITON_ATTN_VLLM_V1 " 
136+                     expected  =  "TRITON_ATTN " 
137137                    assert  backend .get_name () ==  expected 
138138
139139        elif  device  ==  "cuda" :
@@ -160,7 +160,7 @@ def test_env(
160160                                                       None ,
161161                                                       block_size ,
162162                                                       use_mla = use_mla )
163-                             expected  =  "CUTLASS_MLA_VLLM_V1 " 
163+                             expected  =  "CUTLASS_MLA " 
164164                            assert  backend .get_name () ==  expected 
165165                    elif  name  ==  "FLASHINFER_MLA" :
166166                        if  block_size  not  in   [32 , 64 ]:
@@ -193,7 +193,7 @@ def test_env(
193193                                                           None ,
194194                                                           block_size ,
195195                                                           use_mla = use_mla )
196-                                 expected  =  f" { name } _VLLM_V1" 
196+                                 expected  =  name 
197197                                assert  backend .get_name () ==  expected 
198198                    elif  name  ==  "FLASH_ATTN_MLA" :
199199                        backend  =  get_attn_backend (16 ,
@@ -210,33 +210,32 @@ def test_env(
210210                                                   None ,
211211                                                   block_size ,
212212                                                   use_mla = use_mla )
213-                         expected  =  "TRITON_MLA_VLLM_V1 " 
213+                         expected  =  "TRITON_MLA " 
214214                        assert  backend .get_name () ==  expected 
215215                elif  name  ==  "FLASHINFER" :
216216                    backend  =  get_attn_backend (16 ,
217217                                               torch .float16 ,
218218                                               None ,
219219                                               block_size ,
220220                                               use_mla = use_mla )
221-                     expected  =  "FLASHINFER_VLLM_V1 " 
221+                     expected  =  "FLASHINFER " 
222222                    assert  backend .get_name () ==  expected 
223-                 else :
223+                 elif   name   ==   "XFORMERS" :
224224                    backend  =  get_attn_backend (32 ,
225225                                               torch .float16 ,
226226                                               None ,
227227                                               block_size ,
228228                                               use_mla = use_mla )
229-                     expected  =  "FLASH_ATTN_VLLM_V1 " 
229+                     expected  =  "XFORMERS " 
230230                    assert  backend .get_name () ==  expected 
231- 
232-                     backend  =  get_attn_backend (16 ,
231+                  elif   name   ==   "FLASH_ATTN" : 
232+                     backend  =  get_attn_backend (32 ,
233233                                               torch .float16 ,
234234                                               None ,
235235                                               block_size ,
236236                                               use_mla = use_mla )
237-                     assert  backend .get_name () ==  "FLEX_ATTENTION" , (
238-                         "Should fallback to FlexAttention if head size is " 
239-                         "not supported by FlashAttention" )
237+                     expected  =  "FLASH_ATTN" 
238+                     assert  backend .get_name () ==  expected 
240239
241240
242241@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ]) 
@@ -252,7 +251,7 @@ def test_fp32_fallback(
252251            with  patch ("vllm.attention.selector.current_platform" ,
253252                       CpuPlatform ()):
254253                backend  =  get_attn_backend (16 , torch .float32 , None , 16 )
255-             assert  backend .get_name () ==  "TORCH_SDPA_VLLM_V1 " 
254+             assert  backend .get_name () ==  "TORCH_SDPA " 
256255
257256        elif  device  ==  "cuda" :
258257            with  patch ("vllm.attention.selector.current_platform" ,
@@ -266,6 +265,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
266265    # TODO: When testing for v1, pipe in `use_v1` as an argument to 
267266    # get_attn_backend 
268267
268+     pytest .skip ("Skipping as current backend selector does not "  \
269+                 "handle fallbacks when a backend is set via env var." )
270+ 
269271    with  monkeypatch .context () as  m :
270272        m .setenv (STR_BACKEND_ENV_VAR , STR_FLASH_ATTN_VAL )
271273
0 commit comments