# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """Tests for v1 MLA backends without GPUModelRunner dependency."""
+ """Tests for v1 MLA backends without GPUModelRunner dependency.
+
+ Known Issues:
+ - FLASH_ATTN_MLA backend occasionally produces NaN values in
+   test_backend_correctness[mixed_small] when run after
+   test_backend_correctness[small_prefill], but passes when run alone.
+ """

import pytest
import torch
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
+ from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+ from vllm.config.vllm import set_current_vllm_config
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
    BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)

+ # Remove FLASHMLA from the list if not supported
+ if not is_flashmla_dense_supported()[0]:
+     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
+
torch.manual_seed(42)


@@ -66,6 +78,12 @@ def _convert_dtype_to_torch(dtype):
    "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
    "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
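+     # The uniform query_lens > 1 below emulate speculative-decode batches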
+     "spec_decode_small": BatchSpec(
+         seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
+     ),
+     "spec_decode_medium": BatchSpec(
+         seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
+     ),
}


@@ -239,61 +257,64 @@ def run_attention_backend(

    builder_cls, impl_cls = try_get_attention_backend(backend)

-     # Build metadata
-     builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-     attn_metadata = builder.build(
-         common_prefix_len=0,
-         common_attn_metadata=common_attn_metadata,
-     )
+     # Set the current vllm config so that get_current_vllm_config() works
+     # in the backend implementations
+     with set_current_vllm_config(vllm_config):
+         # Build metadata
+         builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+         attn_metadata = builder.build(
+             common_prefix_len=0,
+             common_attn_metadata=common_attn_metadata,
+         )

-     # Instantiate MLA implementation
-     num_heads = vllm_config.model_config.get_num_attention_heads(
-         vllm_config.parallel_config
-     )
-     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-         vllm_config.parallel_config
-     )
-     head_size = vllm_config.model_config.get_head_size()
-     scale = 1.0 / (head_size**0.5)
-     impl = impl_cls(
-         num_heads=num_heads,
-         head_size=head_size,
-         scale=scale,
-         num_kv_heads=num_kv_heads,
-         alibi_slopes=None,
-         sliding_window=None,
-         kv_cache_dtype="auto",
-         logits_soft_cap=None,
-         attn_type="decoder",
-         kv_sharing_target_layer_name=None,
-         q_lora_rank=None,
-         kv_lora_rank=kv_lora_rank,
-         qk_nope_head_dim=qk_nope_head_dim,
-         qk_rope_head_dim=qk_rope_head_dim,
-         qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
-         v_head_dim=v_head_dim,
-         kv_b_proj=mock_kv_b_proj,
-     )
+         # Instantiate MLA implementation
+         num_heads = vllm_config.model_config.get_num_attention_heads(
+             vllm_config.parallel_config
+         )
+         num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+             vllm_config.parallel_config
+         )
+         head_size = vllm_config.model_config.get_head_size()
+         scale = 1.0 / (head_size**0.5)
+         impl = impl_cls(
+             num_heads=num_heads,
+             head_size=head_size,
+             scale=scale,
+             num_kv_heads=num_kv_heads,
+             alibi_slopes=None,
+             sliding_window=None,
+             kv_cache_dtype="auto",
+             logits_soft_cap=None,
+             attn_type="decoder",
+             kv_sharing_target_layer_name=None,
+             q_lora_rank=None,
+             kv_lora_rank=kv_lora_rank,
+             qk_nope_head_dim=qk_nope_head_dim,
+             qk_rope_head_dim=qk_rope_head_dim,
+             qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
+             v_head_dim=v_head_dim,
+             kv_b_proj=mock_kv_b_proj,
+         )

-     # Process weights to create W_UK_T and W_UV attributes needed by MLA
-     act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-     impl.process_weights_after_loading(act_dtype)
+         # Process weights to create W_UK_T and W_UV attributes needed by MLA
+         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
+         impl.process_weights_after_loading(act_dtype)

-     # Create mock layer and output buffer
-     mock_layer = MockAttentionLayer(device)
-     num_tokens = query.shape[0]
-     output = torch.empty(
-         num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
-     )
+         # Create mock layer and output buffer
+         mock_layer = MockAttentionLayer(device)
+         num_tokens = query.shape[0]
+         output = torch.empty(
+             num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
+         )

-     # Run forward pass
-     # NOTE: The query, key, and value are already shaped correctly
-     # in the calling test function.
-     output = impl.forward(
-         mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
-     )
+         # Run forward pass
+         # NOTE: The query, key, and value are already shaped correctly
+         # in the calling test function.
+         output = impl.forward(
+             mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
+         )

-     return output
+         return output


@pytest.mark.parametrize(
@@ -309,6 +330,8 @@ def run_attention_backend(
        "large_prefill",
        "single_decode",
        "single_prefill",
+         "spec_decode_small",
+         "spec_decode_medium",
    ],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.
    """
+     from vllm.v1.attention.backends.mla.common import QueryLenSupport
+
    batch_spec = BATCH_SPECS[batch_spec_name]
+     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
+     spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
+
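+     # Size the KV cache to the batch: ceil(seq_len / block_size) blocks per sequence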
+     block_size = 16
+     required_blocks = sum(
+         (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+     )
+     # Add 1 for null block at index 0, and some buffer
+     num_gpu_blocks = required_blocks + 1 + 100
+
    vllm_config = create_vllm_config(
-         model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
+         model_name=model,
+         max_model_len=max(batch_spec.seq_lens),
+         num_gpu_blocks=num_gpu_blocks,
+         block_size=block_size,
    )
+
+     # For spec decode tests, add a speculative_config to set the
+     # reorder_batch_threshold
+     if is_spec_decode_test:
+         from vllm.config import SpeculativeConfig
+
+         # Get the query length from the batch spec (they should all be uniform)
+         query_len = batch_spec.query_lens[0]
+         # Set num_speculative_tokens to query_len - 1
+         # (since threshold is 1 + num_spec_tokens)
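+         # e.g. query_lens == [4, 4, 4, 4] gives num_speculative_tokens = 3,
+         # i.e. a threshold of 1 + 3 = 4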
+         # Use ngram method which doesn't require a draft model
+         vllm_config.speculative_config = SpeculativeConfig(
+             method="ngram", num_speculative_tokens=query_len - 1
+         )
+
    device = torch.device("cuda:0")

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
        k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

-         # Determine if this is decode or prefill
+         # Determine, for each backend, whether this sequence goes through the
+         # decode pipeline or the prefill pipeline
+         # NOTE: For spec decode tests with uniform query_len > 1, backends that
+         # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
+         # uniform support) will use the decode pipeline (MQA-style), while
+         # backends that only support single-token queries will use the prefill
+         # pipeline (MHA-style). This ensures the reference implementation
+         # matches each backend's actual decode/prefill pipeline path.
        is_decode = []
-         for i, backend in enumerate(BACKENDS_TO_TEST):
+         for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
            builder_cls, _ = try_get_attention_backend(backend)
-             is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
+             if is_spec_decode_test:
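+                 # Any support level other than SINGLE_ONLY means the backend
+                 # can run these uniform multi-token queries on the decode
+                 # (MQA-style) path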
+                 query_len_support = getattr(
+                     builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                 )
+                 supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
+                 is_decode.append(supports_spec)
+             else:
+                 threshold = getattr(builder_cls, "reorder_batch_threshold", None)
+                 query_len_support = getattr(
+                     builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                 )
+                 within_threshold = q_len <= threshold if threshold else False
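+                 # UNIFORM support additionally requires every query length in
+                 # the batch to match the first sequence's query length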
+                 if (
+                     within_threshold
+                     and query_len_support == QueryLenSupport.UNIFORM
+                     and i > 0
+                 ):
+                     first_q_len = query_lens[0]
+                     within_threshold = q_len == first_q_len
+                 is_decode.append(within_threshold)

        # Split q into nope and rope components
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
        sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

-         for i, backend in enumerate(BACKENDS_TO_TEST):
-             if is_decode[i]:
-                 all_sdpa_outputs[i].append(sdpa_out_i_decode)
+         for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+             if is_decode[backend_idx]:
+                 all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
            else:
-                 all_sdpa_outputs[i].append(sdpa_out_i_prefill)
+                 all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)

        # Inputs for vLLM MLA backends are just the new tokens
        all_q_vllm.append(q_c)
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
-     sdpa_outputs = []
-     for i, backend in enumerate(BACKENDS_TO_TEST):
-         sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
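+     # Key the reference outputs by backend, since each backend may have taken
+     # a different mix of decode- and prefill-pipeline references above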
+     sdpa_outputs = {}
+     for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+         sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

    # Create mock kv_b_proj using the same weights as reference implementation
    from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
    )
-     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
+     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

    # Create metadata using original batch spec
    common_attn_metadata = create_common_attn_metadata(
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    )

    # 4. Run vLLM backends and compare
-     for i, backend_name in enumerate(BACKENDS_TO_TEST):
+     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
+         # For spec decode tests, skip backends that don't support spec decode
+         if is_spec_decode_test and backend_name not in spec_decode_backends:
+             continue
+
        backend_output = run_attention_backend(
            backend_name,
            kv_cache_spec,
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
            mock_kv_b_proj,
        )

+         # Look up the SDPA reference output computed for this backend
+         expected_output = sdpa_outputs[backend_name]
+
        # Check shape and dtype consistency
-         assert backend_output.shape == sdpa_outputs[i].shape, (
+         assert backend_output.shape == expected_output.shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
-             f"SDPA shape {sdpa_outputs[i].shape}"
+             f"SDPA shape {expected_output.shape}"
        )
-         assert backend_output.dtype == sdpa_outputs[i].dtype, (
+         assert backend_output.dtype == expected_output.dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
-             f"SDPA dtype {sdpa_outputs[i].dtype}"
+             f"SDPA dtype {expected_output.dtype}"
        )

        assert torch.isfinite(backend_output).all(), (
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        rtol = 1e-2
        atol = 5e-1

-         max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
+         max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
        max_rel_diff = torch.max(
-             torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
+             torch.abs(backend_output - expected_output) / torch.abs(expected_output)
        ).item()
        all_close = torch.allclose(
-             backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
+             backend_output, expected_output, rtol=rtol, atol=atol
        )

        assert all_close, (