
Commit 6872684

MatthewBonanni authored and xuebwang-amd committed
[Attention][Spec Decode] FlashMLA spec decode support (vllm-project#26541)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 14a0d33 commit 6872684

File tree

5 files changed: +214 -91 lines changed


tests/v1/attention/test_mla_backends.py

Lines changed: 156 additions & 71 deletions
@@ -1,6 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for v1 MLA backends without GPUModelRunner dependency."""
+"""Tests for v1 MLA backends without GPUModelRunner dependency.
+
+Known Issues:
+- FLASH_ATTN_MLA backend occasionally produces NaN values in
+  test_backend_correctness[mixed_small] when run after
+  test_backend_correctness[small_prefill], but passes when run alone.
+"""

 import pytest
 import torch
@@ -14,6 +20,8 @@
 )
 from vllm import _custom_ops as ops
 from vllm.attention.backends.registry import _Backend
+from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.config.vllm import set_current_vllm_config
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -29,6 +37,10 @@
 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
     BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)

+# Remove FLASHMLA from the list if not supported
+if not is_flashmla_dense_supported()[0]:
+    BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
+
 torch.manual_seed(42)


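The new gate above only consults the first element of the tuple returned by is_flashmla_dense_supported(). A minimal sketch of the same backend-filtering pattern (the helper name filter_mla_backends is illustrative, not from the commit):

from vllm.attention.backends.registry import _Backend
from vllm.attention.ops.flashmla import is_flashmla_dense_supported


def filter_mla_backends(candidates: list[_Backend]) -> list[_Backend]:
    # is_flashmla_dense_supported() returns a tuple; index 0 is the support
    # flag, exactly as the test above uses it. Other elements are ignored here.
    if is_flashmla_dense_supported()[0]:
        return candidates
    # Drop FLASHMLA when dense FlashMLA kernels are unavailable on this platform.
    return [b for b in candidates if b is not _Backend.FLASHMLA]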
@@ -66,6 +78,12 @@ def _convert_dtype_to_torch(dtype):
     "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
     "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
     "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
+    "spec_decode_small": BatchSpec(
+        seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
+    ),
+    "spec_decode_medium": BatchSpec(
+        seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
+    ),
 }


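The two new batch specs give every request in the batch the same query length (4 or 8 new tokens), which is what lets spec-decode-capable backends treat the whole batch as decode. A small illustrative calculation, assuming BatchSpec simply carries the two lists shown above:

# Mirrors the "spec_decode_small" entry added above (illustrative dict,
# not the real BatchSpec type).
spec_decode_small = {"seq_lens": [128, 256, 512, 1024], "query_lens": [4, 4, 4, 4]}

# Uniform query length across the batch is the property the decode path relies on.
assert len(set(spec_decode_small["query_lens"])) == 1

# New tokens handed to the attention backend vs. context already in the KV cache.
new_tokens = sum(spec_decode_small["query_lens"])    # 16
context_tokens = sum(spec_decode_small["seq_lens"])  # 1920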
@@ -239,61 +257,64 @@ def run_attention_backend(

     builder_cls, impl_cls = try_get_attention_backend(backend)

-    # Build metadata
-    builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-    attn_metadata = builder.build(
-        common_prefix_len=0,
-        common_attn_metadata=common_attn_metadata,
-    )
+    # Set the current vllm config so that get_current_vllm_config() works
+    # in the backend implementations
+    with set_current_vllm_config(vllm_config):
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )

-    # Instantiate MLA implementation
-    num_heads = vllm_config.model_config.get_num_attention_heads(
-        vllm_config.parallel_config
-    )
-    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-        vllm_config.parallel_config
-    )
-    head_size = vllm_config.model_config.get_head_size()
-    scale = 1.0 / (head_size**0.5)
-    impl = impl_cls(
-        num_heads=num_heads,
-        head_size=head_size,
-        scale=scale,
-        num_kv_heads=num_kv_heads,
-        alibi_slopes=None,
-        sliding_window=None,
-        kv_cache_dtype="auto",
-        logits_soft_cap=None,
-        attn_type="decoder",
-        kv_sharing_target_layer_name=None,
-        q_lora_rank=None,
-        kv_lora_rank=kv_lora_rank,
-        qk_nope_head_dim=qk_nope_head_dim,
-        qk_rope_head_dim=qk_rope_head_dim,
-        qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
-        v_head_dim=v_head_dim,
-        kv_b_proj=mock_kv_b_proj,
-    )
+        # Instantiate MLA implementation
+        num_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config
+        )
+        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        head_size = vllm_config.model_config.get_head_size()
+        scale = 1.0 / (head_size**0.5)
+        impl = impl_cls(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=None,
+            sliding_window=None,
+            kv_cache_dtype="auto",
+            logits_soft_cap=None,
+            attn_type="decoder",
+            kv_sharing_target_layer_name=None,
+            q_lora_rank=None,
+            kv_lora_rank=kv_lora_rank,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
+            v_head_dim=v_head_dim,
+            kv_b_proj=mock_kv_b_proj,
+        )

-    # Process weights to create W_UK_T and W_UV attributes needed by MLA
-    act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    impl.process_weights_after_loading(act_dtype)
+        # Process weights to create W_UK_T and W_UV attributes needed by MLA
+        act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
+        impl.process_weights_after_loading(act_dtype)

-    # Create mock layer and output buffer
-    mock_layer = MockAttentionLayer(device)
-    num_tokens = query.shape[0]
-    output = torch.empty(
-        num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
-    )
+        # Create mock layer and output buffer
+        mock_layer = MockAttentionLayer(device)
+        num_tokens = query.shape[0]
+        output = torch.empty(
+            num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
+        )

-    # Run forward pass
-    # NOTE: The query, key, and value are already shaped correctly
-    # in the calling test function.
-    output = impl.forward(
-        mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
-    )
+        # Run forward pass
+        # NOTE: The query, key, and value are already shaped correctly
+        # in the calling test function.
+        output = impl.forward(
+            mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
+        )

-    return output
+        return output


 @pytest.mark.parametrize(
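The indentation change above re-parents the existing build-and-run steps under set_current_vllm_config so that get_current_vllm_config() (mentioned in the new comment) resolves to the test's config. A condensed sketch of just that pattern, assuming the same builder interface used in the hunk:

from vllm.config.vllm import set_current_vllm_config


def build_metadata_under_config(
    builder_cls, kv_cache_spec, layer_names, vllm_config, device, common_attn_metadata
):
    # Anything constructed inside this block can call get_current_vllm_config()
    # and observe vllm_config, which some MLA backend builders do during build().
    with set_current_vllm_config(vllm_config):
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        return builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )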
@@ -309,6 +330,8 @@ def run_attention_backend(
         "large_prefill",
         "single_decode",
         "single_prefill",
+        "spec_decode_small",
+        "spec_decode_medium",
     ],
 )
 @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
+    from vllm.v1.attention.backends.mla.common import QueryLenSupport
+
     batch_spec = BATCH_SPECS[batch_spec_name]
+    is_spec_decode_test = batch_spec_name.startswith("spec_decode")
+    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
+
+    block_size = 16
+    required_blocks = sum(
+        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+    )
+    # Add 1 for null block at index 0, and some buffer
+    num_gpu_blocks = required_blocks + 1 + 100
+
     vllm_config = create_vllm_config(
-        model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
+        model_name=model,
+        max_model_len=max(batch_spec.seq_lens),
+        num_gpu_blocks=num_gpu_blocks,
+        block_size=block_size,
     )
+
+    # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
+    if is_spec_decode_test:
+        from vllm.config import SpeculativeConfig
+
+        # Get the query length from the batch spec (they should all be uniform)
+        query_len = batch_spec.query_lens[0]
+        # Set num_speculative_tokens to query_len - 1
+        # (since threshold is 1 + num_spec_tokens)
+        # Use ngram method which doesn't require a draft model
+        vllm_config.speculative_config = SpeculativeConfig(
+            method="ngram", num_speculative_tokens=query_len - 1
+        )
+
     device = torch.device("cuda:0")

     kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
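The new comments encode the relationship reorder_batch_threshold = 1 + num_speculative_tokens, so a uniform query length of q tokens needs q - 1 speculative tokens. A hedged sketch using only the constructor arguments that appear in the hunk (reading the attribute back in the assert is an assumption, not shown in the commit):

from vllm.config import SpeculativeConfig

# spec_decode_small uses query_lens of 4: one "real" token plus 3 speculative
# tokens are verified together, so the decode reorder threshold becomes 4.
query_len = 4
spec_config = SpeculativeConfig(method="ngram", num_speculative_tokens=query_len - 1)
assert 1 + spec_config.num_speculative_tokens == query_len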
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
         k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

-        # Determine if this is decode or prefill
+        # Determine if this sequence uses the decode pipeline or prefill
+        # pipeline for each backend
+        # NOTE: For spec decode tests with uniform query_len > 1, backends that
+        # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
+        # uniform support) will use the decode pipeline (MQA-style), while
+        # backends that only support single-token queries will use the prefill
+        # pipeline (MHA-style). This ensures the reference implementation
+        # matches each backend's actual decode/prefill pipeline path.
         is_decode = []
-        for i, backend in enumerate(BACKENDS_TO_TEST):
+        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
             builder_cls, _ = try_get_attention_backend(backend)
-            is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
+            if is_spec_decode_test:
+                query_len_support = getattr(
+                    builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                )
+                supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
+                is_decode.append(supports_spec)
+            else:
+                threshold = getattr(builder_cls, "reorder_batch_threshold", None)
+                query_len_support = getattr(
+                    builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+                )
+                within_threshold = q_len <= threshold if threshold else False
+                if (
+                    within_threshold
+                    and query_len_support == QueryLenSupport.UNIFORM
+                    and i > 0
+                ):
+                    first_q_len = query_lens[0]
+                    within_threshold = q_len == first_q_len
+                is_decode.append(within_threshold)

         # Split q into nope and rope components
         q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
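The per-backend classification above can be read as a small decision rule. A simplified sketch (the UNIFORM first-sequence length check from the hunk is collapsed, and only the QueryLenSupport members referenced in the diff are used):

from vllm.v1.attention.backends.mla.common import QueryLenSupport


def uses_decode_pipeline(builder_cls, q_len: int, is_spec_decode_test: bool) -> bool:
    # Backends default to single-token decode support when they do not declare
    # a query_len_support attribute.
    support = getattr(builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY)
    if is_spec_decode_test:
        # Anything beyond SINGLE_ONLY can take the MQA-style decode path for
        # uniform multi-token queries.
        return support != QueryLenSupport.SINGLE_ONLY
    threshold = getattr(builder_cls, "reorder_batch_threshold", None)
    # Outside spec decode tests a sequence counts as decode only if its query
    # length fits under the backend's reorder threshold.
    return q_len <= threshold if threshold else False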
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
         sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

-        for i, backend in enumerate(BACKENDS_TO_TEST):
-            if is_decode[i]:
-                all_sdpa_outputs[i].append(sdpa_out_i_decode)
+        for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+            if is_decode[backend_idx]:
+                all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
             else:
-                all_sdpa_outputs[i].append(sdpa_out_i_prefill)
+                all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)

         # Inputs for vLLM MLA backends are just the new tokens
         all_q_vllm.append(q_c)
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     query_vllm = torch.cat(all_q_vllm, dim=0)
     kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
     k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
-    sdpa_outputs = []
-    for i, backend in enumerate(BACKENDS_TO_TEST):
-        sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
+    sdpa_outputs = {}
+    for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
+        sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

     # Create mock kv_b_proj using the same weights as reference implementation
     from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     kv_b_proj_weight = kv_b_proj_weight.view(
         kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
     )
-    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
+    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

     # Create metadata using original batch spec
     common_attn_metadata = create_common_attn_metadata(
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )

     # 4. Run vLLM backends and compare
-    for i, backend_name in enumerate(BACKENDS_TO_TEST):
+    for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
+        # Skip backends that don't support spec decode for spec decode tests
+        if is_spec_decode_test and backend_name not in spec_decode_backends:
+            continue
+
         backend_output = run_attention_backend(
             backend_name,
             kv_cache_spec,
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
             mock_kv_b_proj,
         )

+        # Use backend_idx to get the correct SDPA output for this backend
+        expected_output = sdpa_outputs[backend_name]
+
         # Check shape and dtype consistency
-        assert backend_output.shape == sdpa_outputs[i].shape, (
+        assert backend_output.shape == expected_output.shape, (
             f"[{backend_name}] shape {backend_output.shape} != "
-            f"SDPA shape {sdpa_outputs[i].shape}"
+            f"SDPA shape {expected_output.shape}"
         )
-        assert backend_output.dtype == sdpa_outputs[i].dtype, (
+        assert backend_output.dtype == expected_output.dtype, (
             f"[{backend_name}] dtype {backend_output.dtype} != "
-            f"SDPA dtype {sdpa_outputs[i].dtype}"
+            f"SDPA dtype {expected_output.dtype}"
         )

         assert torch.isfinite(backend_output).all(), (
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         rtol = 1e-2
         atol = 5e-1

-        max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
+        max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
         max_rel_diff = torch.max(
-            torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
+            torch.abs(backend_output - expected_output) / torch.abs(expected_output)
         ).item()
         all_close = torch.allclose(
-            backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
+            backend_output, expected_output, rtol=rtol, atol=atol
         )

         assert all_close, (
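For reference, torch.allclose applies an elementwise |backend - expected| <= atol + rtol * |expected| test, so with rtol=1e-2 and atol=5e-1 an element passes when its absolute error stays within 0.5 plus 1% of the reference magnitude. A tiny standalone illustration (values are made up):

import torch

rtol, atol = 1e-2, 5e-1
expected = torch.tensor([10.0, 0.1])
backend = torch.tensor([10.55, 0.7])  # absolute errors of 0.55 and 0.6

# Elementwise criterion used by torch.allclose:
passes = (backend - expected).abs() <= atol + rtol * expected.abs()
print(passes)  # tensor([ True, False]) -> 0.55 <= 0.6 but 0.6 > 0.501
print(torch.allclose(backend, expected, rtol=rtol, atol=atol))  # False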
