@@ -146,6 +146,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             # required block_size.
             use_flashmla = False
             use_cutlass_mla = False
+            use_flashinfer_mla = False
 
             if envs.VLLM_ATTENTION_BACKEND is None:
                 # Default case
@@ -164,6 +165,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                 use_cutlass_mla = (
                     envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
+                use_flashinfer_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")
 
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
@@ -176,6 +179,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
                             "CUTLASS_MLA backend.")
+            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashInferMLA "
+                    "backend.")
 
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
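
Note (not part of the diff): the added check mirrors the existing FlashMLA and CUTLASS_MLA block-size enforcement: if FLASHINFER_MLA is in use and the configured kv-cache block size is not 32 or 64, it is forced to 64. A minimal standalone sketch of that policy, with illustrative names only:

    def force_mla_block_size(backend: str, block_size: int) -> int:
        # Simplified restatement of the checks above; not the actual vLLM code.
        # (The real code also gates the FlashMLA case on is_flashmla_supported().)
        if backend == "FLASHMLA" and block_size != 64:
            return 64
        if backend == "CUTLASS_MLA" and block_size != 128:
            return 128
        if backend == "FLASHINFER_MLA" and block_size not in (32, 64):
            return 64
        return block_size

    assert force_mla_block_size("FLASHINFER_MLA", 16) == 64   # coerced to 64
    assert force_mla_block_size("FLASHINFER_MLA", 32) == 32   # 32 is already valid
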
@@ -228,8 +236,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
-            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
-                                 and cls.has_device_capability(100))
+            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+                selected_backend is None and cls.is_device_capability(100)
+                and block_size in [32, 64])
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
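
Note (not part of the diff): previously FlashInfer MLA was chosen only when explicitly selected and the device reported compute capability 10.0 or newer; with this change an explicit selection is honored as-is, and when no backend is selected it is also auto-picked on capability-10.0 devices whose kv-cache block size is 32 or 64. A sketch of the new eligibility condition, assuming a boolean capability flag for brevity:

    def flashinfer_mla_eligible(selected_backend, on_capability_100, block_size):
        # Restates the new use_flashinfermla condition; names are illustrative.
        if selected_backend == "FLASHINFER_MLA":
            return True
        return (selected_backend is None
                and on_capability_100
                and block_size in (32, 64))

    assert flashinfer_mla_eligible(None, True, 64) is True    # auto-selected on capability 10.0
    assert flashinfer_mla_eligible(None, True, 128) is False  # 128 routes to CUTLASS_MLA instead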