
Commit e0329ed

Updates to Flex + vLLM integration (#21416)
Signed-off-by: drisspg <drisspguessous@gmail.com>
1 parent 6879cd8 commit e0329ed

3 files changed: +439 -103 lines changed

tests/kernels/test_flex_attention.py

Lines changed: 87 additions & 23 deletions
@@ -9,12 +9,17 @@
 import torch
 from packaging import version
 
-from vllm import SamplingParams
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      create_vllm_config)
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder)
 
-from ..models.utils import check_embeddings_close
+from ..models.utils import check_embeddings_close, check_logprobs_close
 
 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
+DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
 
 
 def set_seed(seed):
@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
-    the default backend, ensuring they are identical when using the same seed.
+    the default backend, ensuring they are similar when using the same seed.
     """
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     seed = 42
     max_tokens = 24
+    num_logprobs = 5
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
     ]
 
-    sampling_params = SamplingParams(temperature=0.0,
-                                     top_p=1.0,
-                                     seed=seed,
-                                     max_tokens=max_tokens)
-
     # Run with flex attention
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
                          enforce_eager=True) as llm_flex:
-            output_flex = llm_flex.generate(prompts, sampling_params)
+            output_flex = llm_flex.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)
 
     # Run with default backend
     with monkeypatch.context() as m:
@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          runner="generate",
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
-                         enforce_eager=True) as llm_default:
-            output_default = llm_default.generate(prompts, sampling_params)
-
-    # Compare outputs from both backends
-    for i, (flex_result,
-            default_result) in enumerate(zip(output_flex, output_default)):
-        prompt = prompts[i]
-        flex_text = flex_result[1][0]
-        default_text = default_result[1][0]
-
-        assert flex_text == default_text, (
-            f"FlexAttention output doesn't match default for: {prompt!r}\n"
-            f"FlexAttention: {flex_text!r}\n"
-            f"Default: {default_text!r}")
+                         enforce_eager=True,
+                         gpu_memory_utilization=0.85) as llm_default:
+            output_default = llm_default.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=output_flex,
+        outputs_1_lst=output_default,
+        name_0="flex",
+        name_1="default",
+    )
 
 
 @pytest.mark.skipif(
@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     )
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
+    reason="CUDA not available or PyTorch version < 2.9",
+)
+def test_block_mask_direct_vs_slow_path():
+    """Test that direct path block mask is a superset of slow path.
+
+    The direct path may include extra blocks for performance (over-estimation),
+    but must include all blocks that the slow path determines are necessary.
+    """
+    device = torch.device("cuda")
+
+    vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
+                                     block_size=16,
+                                     max_model_len=1024)
+    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
+
+    # Use a mixed batch that will create groups spanning multiple sequences
+    batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
+                           query_lens=[33, 5, 32, 64],
+                           name="test_mixed_batch")
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, vllm_config.cache_config.block_size, device)
+
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
+                                           device)
+
+    metadata_direct = builder.build(common_prefix_len=0,
+                                    common_attn_metadata=common_attn_metadata)
+    builder.direct_build = False
+    metadata_slow = builder.build(common_prefix_len=0,
+                                  common_attn_metadata=common_attn_metadata)
+
+    assert metadata_direct.block_mask is not None
+    assert metadata_slow.block_mask is not None
+
+    # Extract block indices for comparison, B, H are the same
+    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
+    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
+    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
+    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]
+
+    # main test: every block needed by slow path must be in direct path
+    num_groups = direct_num.shape[0]
+    all_contained = True
+    missing_details = []
+
+    for group_idx in range(num_groups):
+        direct_blocks = set(
+            direct_indices[group_idx, :direct_num[group_idx]].tolist())
+        slow_blocks = set(
+            slow_indices[group_idx, :slow_num[group_idx]].tolist())
+
+        missing_blocks = slow_blocks - direct_blocks
+        if missing_blocks:
+            all_contained = False
+            missing_details.append(
+                f"Group {group_idx}: missing {sorted(missing_blocks)}")
+
+    assert all_contained, (
+        "Direct path is missing blocks required by slow path:\n" +
+        "\n".join(missing_details))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
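
Note on the comparison change above: the first test no longer requires byte-identical generations from the two backends. Roughly speaking, check_logprobs_close accepts divergent outputs as long as, at the first position where the greedily decoded tokens differ, each backend's token still appears among the other backend's top num_logprobs candidates. The snippet below is a minimal, self-contained sketch of that idea with a hypothetical helper and toy data; it is not vLLM's implementation of check_logprobs_close.

# Hypothetical sketch of a top-k logprob closeness check (illustrative only).
# Each output is (token_ids, text, top_logprobs), where top_logprobs[i] maps
# candidate token ids to their logprobs at position i.
def logprobs_close(output_a, output_b) -> bool:
    tokens_a, _, logprobs_a = output_a
    tokens_b, _, logprobs_b = output_b
    for pos, (tok_a, tok_b) in enumerate(zip(tokens_a, tokens_b)):
        if tok_a == tok_b:
            continue
        # A divergence is tolerated only if each side's token is still a
        # top-k candidate under the other side at this position.
        if tok_a not in logprobs_b[pos] or tok_b not in logprobs_a[pos]:
            return False
        # After the first divergence the continuations are conditioned on
        # different prefixes, so later positions are not compared.
        break
    return True

# Toy data: both agree on token 5, then diverge (7 vs. 8), but each divergent
# token appears in the other's top-k logprobs, so the outputs count as close.
out_flex = ([5, 7], "flex text", [{5: -0.1, 9: -2.3}, {7: -0.3, 8: -0.9}])
out_default = ([5, 8], "default text", [{5: -0.1, 9: -2.4}, {8: -0.4, 7: -1.0}])
assert logprobs_close(out_flex, out_default)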

tests/v1/attention/test_attention_backends.py

Lines changed: 18 additions & 12 deletions
@@ -10,14 +10,15 @@
                                           create_standard_kv_cache_spec,
                                           create_vllm_config,
                                           get_attention_backend)
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
     _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
-    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
+    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
+    "FLEX_ATTENTION_SLOW"
 ]
 
 # Remove flashinfer from the list if it's not available
@@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
         common_attn_metadata: CommonAttentionMetadata,
         randomize_blocks: bool = True) -> torch.Tensor:
     """Create and prepopulate a KV cache with context data.
-
+
     Args:
         k_contexts: List of key context tensors for each sequence
         v_contexts: List of value context tensors for each sequence
@@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
         device: Device to create the cache on
         num_blocks: Total number of blocks in the cache
         block_table: Block table tensor to populate
-        randomize_blocks: Whether to randomly permute blocks
+        randomize_blocks: Whether to randomly permute blocks
             or use sequential order
-
+
     Returns:
         Tuple of (kv_cache, updated_block_table)
     """
@@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
                           kv_cache: torch.Tensor) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
-    builder_cls, impl_cls = get_attention_backend(backend)
+    # Handle special case for FLEX_ATTENTION_SLOW
+    actual_backend = backend
+
+    use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
+    if backend == "FLEX_ATTENTION_SLOW":
+        actual_backend = _Backend.FLEX_ATTENTION
+        use_direct_block_mask = False
+
+    builder_cls, impl_cls = get_attention_backend(actual_backend)
 
     # Mock flashinfer's get_per_layer_parameters if needed
-    if backend == _Backend.FLASHINFER_VLLM_V1:
+    if actual_backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
         from vllm.v1.attention.backends.utils import PerLayerParameters
@@ -239,6 +248,8 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
     else:
         # Build metadata
         builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        if actual_backend == _Backend.FLEX_ATTENTION:
+            builder.direct_build = use_direct_block_mask
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
@@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         rtol = 1e-2
         atol = 5e-3
 
-        if backend_name == _Backend.FLEX_ATTENTION:
-            atol = 5e-1  # TODO: figure out why flex_attention has such large
-            # numerical differences for medium_decode, medium_prefill,
-            # mixed_medium
-
         max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
         max_rel_diff = torch.max(
             torch.abs(backend_output - sdpa_output) /
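
A note on "FLEX_ATTENTION_SLOW" above: it is not a new _Backend enum member but a sentinel string added to BACKENDS_TO_TEST. run_attention_backend maps it back onto _Backend.FLEX_ATTENTION with direct_build disabled, so the slow block-mask construction path keeps running through the same correctness checks, and the old 5e-1 atol carve-out for FlexAttention is removed, so it is now held to the common tolerances. A generic sketch of that sentinel-mapping pattern, using hypothetical names rather than vLLM's, looks like this:

# Generic sketch: mapping a sentinel test-matrix entry onto a real backend
# plus a configuration flag (names here are hypothetical).
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    FLEX_ATTENTION = auto()


# The matrix mixes real backends with a string sentinel selecting a variant.
BACKENDS_UNDER_TEST = [Backend.FLASH_ATTN, Backend.FLEX_ATTENTION,
                       "FLEX_ATTENTION_SLOW"]


def resolve(entry):
    """Return (actual backend, use_direct_block_mask) for a matrix entry."""
    if entry == "FLEX_ATTENTION_SLOW":
        return Backend.FLEX_ATTENTION, False
    return entry, True


for entry in BACKENDS_UNDER_TEST:
    actual, direct_build = resolve(entry)
    print(entry, "->", actual.name, "direct_build =", direct_build)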
