 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for v1 attention backends without GPUModelRunner dependency."""
+from functools import partial
+from typing import Optional, Union

 import pytest
 import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention

 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.config import ModelConfig
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
@@ -183,13 +188,19 @@ def __init__(self, device: torch.device):
         self._v_scale_float = 1.0


-def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          layer_names: list[str], vllm_config,
-                          device: torch.device,
-                          common_attn_metadata: CommonAttentionMetadata,
-                          query: torch.Tensor, key: torch.Tensor,
-                          value: torch.Tensor,
-                          kv_cache: torch.Tensor) -> torch.Tensor:
+def run_attention_backend(
+    backend: _Backend,
+    kv_cache_spec: FullAttentionSpec,
+    layer_names: list[str],
+    vllm_config,
+    device: torch.device,
+    common_attn_metadata: CommonAttentionMetadata,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    sliding_window: Optional[int] = None,
+) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""

     # Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
         scale=scale,
         num_kv_heads=num_kv_heads,
         alibi_slopes=None,
-        sliding_window=None,
+        sliding_window=sliding_window,
         kv_cache_dtype="auto",
     )

@@ -275,13 +286,16 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
     return output


-@pytest.mark.parametrize("batch_spec_name", [
-    "small_decode", "small_prefill", "mixed_small", "medium_decode",
-    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
-    "single_decode", "single_prefill"
-])
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_backend_correctness(batch_spec_name: str, model: str):
+def _test_backend_correctness(
+    batch_spec: BatchSpec,
+    model: str,
+    backend_to_test: list[Union[_Backend, str]],
+    mask_mod,
+    *,
+    block_size: int = 16,
+    atol: float = 1e-2,
+    rtol: float = 1e-2,
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    batch_spec = BATCH_SPECS[batch_spec_name]
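+    # Seed all RNGs up front so the randomly generated inputs (and thus the
+    # reference outputs) are reproducible across runs.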
+    current_platform.seed_everything(42)
     vllm_config = create_vllm_config(model_name=model,
                                      max_model_len=max(batch_spec.seq_lens),
+                                     block_size=block_size,
                                      num_gpu_blocks=8192)
     device = torch.device("cuda:0")

@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
         vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
+    sliding_window = vllm_config.model_config.get_sliding_window()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
     scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # Create causal mask: query token i attends to positions 0 to
         # (context_len + i)
         kv_len = s_len
-        offset = context_len
-        attn_mask = torch.full((q_len, kv_len),
-                               float('-inf'),
-                               device=device,
-                               dtype=dtype)
-        for i in range(q_len):
-            attn_mask[i, :offset + i + 1] = 0.0
-
-        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
-            q_sdpa_in,
-            k_sdpa_in,
-            v_sdpa_in,
-            attn_mask=attn_mask,
-            scale=scale,
-            enable_gqa=True)
-        # Convert back to (L, H, D)
+
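+        # Build a FlexAttention block mask from the caller-supplied mask_mod
+        # (with this sequence's context_len bound via partial), so the same
+        # reference path covers both the causal and sliding-window variants.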
+        final_mask_mod = partial(mask_mod, context_len=context_len)
+        block_mask = create_block_mask(final_mask_mod,
+                                       B=None,
+                                       H=None,
+                                       Q_LEN=q_len,
+                                       KV_LEN=kv_len,
+                                       device=device)
+        sdpa_out_i = flex_attention(q_sdpa_in,
+                                    k_sdpa_in,
+                                    v_sdpa_in,
+                                    block_mask=block_mask,
+                                    scale=scale,
+                                    enable_gqa=True)
+
         all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

         # Inputs for vLLM backends are just the new tokens
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
     # with test infrastructures
-    for backend_name in BACKENDS_TO_TEST:
+    for backend_name in backend_to_test:
         # FlashAttention + FlexAttention:
         # [2, num_blocks, block_size, num_kv_heads, head_size]
         # FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
                 2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")

-        backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               ["placeholder"], vllm_config,
-                                               device, common_attn_metadata,
-                                               query_vllm, key_vllm,
-                                               value_vllm,
-                                               kv_cache_for_backend)
+        backend_output = run_attention_backend(
+            backend_name,
+            kv_cache_spec,
+            ["placeholder"],
+            vllm_config,
+            device,
+            common_attn_metadata,
+            query_vllm,
+            key_vllm,
+            value_vllm,
+            kv_cache_for_backend,
+            sliding_window=sliding_window,
+        )

         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
446468 f"[{ backend_name } ] produced non-finite values" )
447469
448470 # Check numerical similarity
449- rtol = 1e-2
450- atol = 5e-3
451-
452- max_diff = torch .max (torch .abs (backend_output - sdpa_output )).item ()
453- max_rel_diff = torch .max (
454- torch .abs (backend_output - sdpa_output ) /
455- torch .abs (sdpa_output )).item ()
456- all_close = torch .allclose (backend_output ,
471+ def error_msg (msg : str , backend_name : str ):
472+ return (f"[{ backend_name } ] output differs from SDPA baseline. "
473+ f"{ msg } " )
474+
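+        # torch.testing.assert_close already reports the largest absolute and
+        # relative mismatches; error_msg just prefixes that report with the
+        # backend name.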
+        torch.testing.assert_close(backend_output,
                                    sdpa_output,
                                    rtol=rtol,
-                                   atol=atol)
+                                   atol=atol,
+                                   msg=partial(error_msg,
+                                               backend_name=backend_name))

-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_small", "medium_decode",
+    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
+    "single_decode", "single_prefill"
+])
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_causal_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with causal attention."""
+
+    def causal_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+    ):
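+        # q_idx indexes only the new query tokens, so shift it by context_len
+        # to get the absolute position before comparing against kv_idx.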
+        return (q_idx + context_len) >= kv_idx
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
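+    # FlexAttention's fast path (used on torch >= 2.9) requires block_size=128,
+    # so it is exercised separately from the other backends below.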
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              causal_mask_mod)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  causal_mask_mod,
+                                  block_size=128)
+
+
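+# Subset of backends exercised by the sliding-window correctness test.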
+SLIDING_WINDOW_BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+]
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_medium", "large_decode",
+    "large_prefill"
+])
+@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
+def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with sliding window attention."""
+
+    def sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
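+        # A key/value position is visible only if it is causal (not ahead of
+        # the query's absolute position) and lies within the last
+        # `sliding_window` positions relative to it.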
+        causal_mask = q_idx + context_len >= kv_idx
+        window_mask = q_idx + context_len - kv_idx < sliding_window
+        return causal_mask & window_mask
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model,
+                               max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
+                                         sliding_window=sliding_window)
+
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
+        if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              sliding_window_mask_mod_fn)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  sliding_window_mask_mod_fn,
+                                  block_size=128)