227 changes: 156 additions & 71 deletions tests/v1/attention/test_mla_backends.py
@@ -1,6 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
"""Tests for v1 MLA backends without GPUModelRunner dependency.

Known Issues:
- FLASH_ATTN_MLA backend occasionally produces NaN values in
test_backend_correctness[mixed_small] when run after
test_backend_correctness[small_prefill], but passes when run alone.
"""

import pytest
import torch
@@ -14,6 +20,8 @@
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.config.vllm import set_current_vllm_config
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -29,6 +37,10 @@
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)

# Remove FLASHMLA from the list if not supported
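# (Judging from the [0] index below, is_flashmla_dense_supported() returns a
# tuple whose first element is the boolean support flag.)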
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)

torch.manual_seed(42)


@@ -66,6 +78,12 @@ def _convert_dtype_to_torch(dtype):
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
"spec_decode_small": BatchSpec(
seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
),
"spec_decode_medium": BatchSpec(
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
),
}


@@ -239,61 +257,64 @@ def run_attention_backend(

builder_cls, impl_cls = try_get_attention_backend(backend)

# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations
with set_current_vllm_config(vllm_config):
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)

# Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
)
# Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
)

# Process weights to create W_UK_T and W_UV attributes needed by MLA
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)
# Process weights to create W_UK_T and W_UV attributes needed by MLA
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)

# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0]
output = torch.empty(
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
)
# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0]
output = torch.empty(
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
)

# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
)
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
)

return output
return output


@pytest.mark.parametrize(
@@ -309,6 +330,8 @@
"large_prefill",
"single_decode",
"single_prefill",
"spec_decode_small",
"spec_decode_medium",
],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
from vllm.v1.attention.backends.mla.common import QueryLenSupport

batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}

block_size = 16
required_blocks = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
)
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100
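# For example, "spec_decode_small" (seq_lens=[128, 256, 512, 1024]) needs
# 8 + 16 + 32 + 64 = 120 blocks, so num_gpu_blocks = 120 + 1 + 100 = 221.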

vllm_config = create_vllm_config(
model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
model_name=model,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=block_size,
)

# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if is_spec_decode_test:
from vllm.config import SpeculativeConfig

# Get the query length from the batch spec (all query lens should be uniform)
query_len = batch_spec.query_lens[0]
# Set num_speculative_tokens to query_len - 1
# (since threshold is 1 + num_spec_tokens)
# Use ngram method which doesn't require a draft model
vllm_config.speculative_config = SpeculativeConfig(
method="ngram", num_speculative_tokens=query_len - 1
)
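# For example, "spec_decode_small" has query_lens=[4, 4, 4, 4], so
# num_speculative_tokens is 3 and the resulting reorder_batch_threshold
# is 1 + 3 = 4, matching the uniform query length.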

device = torch.device("cuda:0")

kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

# Determine if this is decode or prefill
# Determine if this sequence uses the decode pipeline or prefill
# pipeline for each backend
# NOTE: For spec decode tests with uniform query_len > 1, backends that
# support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
# uniform support) will use the decode pipeline (MQA-style), while
# backends that only support single-token queries will use the prefill
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = try_get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
if is_spec_decode_test:
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
is_decode.append(supports_spec)
else:
threshold = getattr(builder_cls, "reorder_batch_threshold", None)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
within_threshold = q_len <= threshold if threshold else False
if (
within_threshold
and query_len_support == QueryLenSupport.UNIFORM
and i > 0
):
first_q_len = query_lens[0]
within_threshold = q_len == first_q_len
is_decode.append(within_threshold)
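# Example: with uniform q_len=4 in a spec decode batch, FLASH_ATTN_MLA
# (varlen support) and FLASHMLA (uniform support) are marked as decode
# here, while SINGLE_ONLY backends are marked as prefill and later
# skipped entirely in the comparison loop.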

# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[backend_idx]:
all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)

# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
sdpa_outputs = {}
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)

# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

# Create metadata using original batch spec
common_attn_metadata = create_common_attn_metadata(
@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)

# 4. Run vLLM backends and compare
for i, backend_name in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in spec_decode_backends:
continue

backend_output = run_attention_backend(
backend_name,
kv_cache_spec,
@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj,
)

# Look up the SDPA reference output for this backend
expected_output = sdpa_outputs[backend_name]

# Check shape and dtype consistency
assert backend_output.shape == sdpa_outputs[i].shape, (
assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_outputs[i].shape}"
f"SDPA shape {expected_output.shape}"
)
assert backend_output.dtype == sdpa_outputs[i].dtype, (
assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_outputs[i].dtype}"
f"SDPA dtype {expected_output.dtype}"
)

assert torch.isfinite(backend_output).all(), (
@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1

max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item()
all_close = torch.allclose(
backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
backend_output, expected_output, rtol=rtol, atol=atol
)

assert all_close, (