@@ -19,45 +19,152 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize(
-    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
+# Define MLA and non-MLA backends separately
+DEVICE_MLA_BACKENDS = {
+    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
+    "cpu": [],
+}
+
+DEVICE_REGULAR_ATTN_BACKENDS = {
+    "cuda": ["XFORMERS", "FLASHINFER"],
+    "hip": ["ROCM_FLASH"],
+    "cpu": ["TORCH_SDPA"],
+}
+
+DEVICE_MLA_BLOCK_SIZES = {
+    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
+    "hip": [16, 1],  # HIP requires special handling for block_size=1
+    "cpu": [16]  # CPU uses a fixed block size in these test cases
+}
+
+
+def generate_params():
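+    """Generate (device, name, use_mla, block_size) test parameters."""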
+    params = []
+    for use_mla in [True, False]:
+        for device in ["cuda", "hip", "cpu"]:
+            backends = DEVICE_MLA_BACKENDS[
+                device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
+            for name in backends:
+                block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
+                    16
+                ]
+                for block_size in block_sizes:
+                    params.append(
+                        pytest.param(
+                            device,
+                            name,
+                            use_mla,
+                            block_size,
+                            id=
+                            f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
+                        ))
+    return params
+
+
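+# Example generated ids (illustrative): "cuda_TRITON_MLA_mla_T_blks16",
+# "hip_ROCM_AITER_MLA_mla_T_blks1", "cpu_TORCH_SDPA_mla_F_blks16".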
+@pytest.mark.parametrize("device, name, use_mla, block_size",
+                         generate_params())
 @pytest.mark.parametrize("use_v1", [True, False])
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_env(
+    device: str,
     name: str,
+    use_mla: bool,
+    block_size: int,
     use_v1: bool,
-    device: str,
     monkeypatch: pytest.MonkeyPatch,
 ):
32- """Test that the attention selector can be set via environment variable.
33- Note that we do not test FlashAttn because it is the default backend.
34- """
35-
76+ """Test attention backend selection with valid device-backend pairs."""
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, name)
+        # VLLM_MLA_DISABLE=1 disables MLA, so invert the use_mla flag here.
+        m.setenv("VLLM_MLA_DISABLE", "0" if use_mla else "1")
 
         if device == "cpu":
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           16, False)
+                                           block_size, False)
                 assert backend.get_name() == "TORCH_SDPA"
+
         elif device == "hip":
             with patch("vllm.attention.selector.current_platform",
                        RocmPlatform()):
-                backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           16, False)
-                EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
-                assert backend.get_name() == EXPECTED
-        else:
-            if name in ["XFORMERS", "FLASHINFER"]:
-                with patch("vllm.attention.selector.current_platform",
-                           CudaPlatform()):
-                    backend = get_attn_backend(16, torch.float16,
-                                               torch.float16, 16, False)
-                    EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
-                    assert backend.get_name() == EXPECTED
+                if use_mla:
+                    # Validate HIP MLA backend-block_size combinations
+                    valid_combination = (
+                        (name == "TRITON_MLA" and block_size != 1)
+                        or (name == "ROCM_AITER_MLA" and block_size == 1))
+
+                    if valid_combination:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        assert backend.get_name() == name
+                    else:
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16,
+                                             torch.float16,
+                                             torch.float16,
+                                             block_size,
+                                             False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
+                    assert backend.get_name() == expected
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                if use_mla:
+                    if name == "FLASHMLA" and block_size == 64:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+
+                        # FlashMLA is only supported on CUDA platforms with a
+                        # specific compute capability.
+                        is_supported, _ = is_flashmla_supported()
+
+                        if not is_supported:
+                            # Skip this case if the platform lacks support.
+                            pytest.skip("FlashMLA not supported")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = ("TRITON_MLA_VLLM_V1"
+                                    if use_v1 else "TRITON_MLA")
+                        assert backend.get_name() == expected
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
 
 
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):