 import torch
 from packaging import version

-from vllm import SamplingParams
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      create_vllm_config)
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder)

-from ..models.utils import check_embeddings_close
+from ..models.utils import check_embeddings_close, check_logprobs_close

 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
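+# Torch version required for FlexAttention's direct block-mask build path
+# (exercised by test_block_mask_direct_vs_slow_path below).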
+DIRECT_BUILD_VERSION = version.parse("2.9.dev0")


 def set_seed(seed):
@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
3439 """Test that FlexAttention produces the same outputs as the default backend.
3540
3641 This test compares the outputs from the FlexAttention backend with
37- the default backend, ensuring they are identical when using the same seed.
42+ the default backend, ensuring they are similar when using the same seed.
3843 """
3944 model_name = "Qwen/Qwen2.5-1.5B-Instruct"
4045 seed = 42
4146 max_tokens = 24
47+ num_logprobs = 5
4248 prompts = [
4349 "Hello, my name is" ,
4450 "The president of the United States is" ,
4551 "The capital of France is" ,
4652 ]
4753
48- sampling_params = SamplingParams (temperature = 0.0 ,
49- top_p = 1.0 ,
50- seed = seed ,
51- max_tokens = max_tokens )
52-
5354 # Run with flex attention
5455 with monkeypatch .context () as m :
5556 m .setenv ("VLLM_USE_V1" , "1" )
@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
                          enforce_eager=True) as llm_flex:
-            output_flex = llm_flex.generate(prompts, sampling_params)
+            output_flex = llm_flex.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)

     # Run with default backend
     with monkeypatch.context() as m:
@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          runner="generate",
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
-                         enforce_eager=True) as llm_default:
-            output_default = llm_default.generate(prompts, sampling_params)
-
-    # Compare outputs from both backends
-    for i, (flex_result,
-            default_result) in enumerate(zip(output_flex, output_default)):
-        prompt = prompts[i]
-        flex_text = flex_result[1][0]
-        default_text = default_result[1][0]
-
-        assert flex_text == default_text, (
-            f"FlexAttention output doesn't match default for: {prompt!r}\n"
-            f"FlexAttention: {flex_text!r}\n"
-            f"Default: {default_text!r}")
+                         enforce_eager=True,
+                         gpu_memory_utilization=0.85) as llm_default:
+            output_default = llm_default.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)
+
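+    # Compare top-k logprobs rather than exact strings, so small numeric
+    # differences between the backends are tolerated.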
+    check_logprobs_close(
+        outputs_0_lst=output_flex,
+        outputs_1_lst=output_default,
+        name_0="flex",
+        name_1="default",
+    )


 @pytest.mark.skipif(
@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     )


+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
+    reason="CUDA not available or PyTorch version < 2.9",
+)
+def test_block_mask_direct_vs_slow_path():
+    """Test that the direct-path block mask is a superset of the slow path's.
+
+    The direct path may include extra blocks for performance (over-estimation),
+    but must include all blocks that the slow path determines are necessary.
+    """
+    device = torch.device("cuda")
+
+    vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
+                                     block_size=16,
+                                     max_model_len=1024)
+    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
+
+    # Use a mixed batch that will create groups spanning multiple sequences
+    batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
+                           query_lens=[33, 5, 32, 64],
+                           name="test_mixed_batch")
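+    # seq_lens exceed query_lens, so each request also attends to previously
+    # cached context tokens, not only the new query tokens.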
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, vllm_config.cache_config.block_size, device)
+
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
+                                           device)
+
+    metadata_direct = builder.build(common_prefix_len=0,
+                                    common_attn_metadata=common_attn_metadata)
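+    # Rebuild with direct_build disabled so the same metadata goes through
+    # the slow/general path, giving a reference mask to compare against.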
+    builder.direct_build = False
+    metadata_slow = builder.build(common_prefix_len=0,
+                                  common_attn_metadata=common_attn_metadata)
+
+    assert metadata_direct.block_mask is not None
+    assert metadata_slow.block_mask is not None
+
+    # Extract block indices for comparison; B and H dims are the same,
+    # so index [0, 0]
+    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
+    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
+    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
+    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]
+    # Main test: every block needed by slow path must be in direct path
+    num_groups = direct_num.shape[0]
+    all_contained = True
+    missing_details = []
+
+    for group_idx in range(num_groups):
+        direct_blocks = set(
+            direct_indices[group_idx, :direct_num[group_idx]].tolist())
+        slow_blocks = set(
+            slow_indices[group_idx, :slow_num[group_idx]].tolist())
+
+        missing_blocks = slow_blocks - direct_blocks
+        if missing_blocks:
+            all_contained = False
+            missing_details.append(
+                f"Group {group_idx}: missing {sorted(missing_blocks)}")
+
+    assert all_contained, (
+        "Direct path is missing blocks required by slow path:\n" +
+        "\n".join(missing_details))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])