12 changes: 11 additions & 1 deletion tests/compile/piecewise/test_full_cudagraph.py
@@ -61,6 +61,16 @@ class BackendConfig:
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
@@ -102,7 +112,7 @@ class BackendConfig:
test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
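Reviewer note: to try only the new backend entry locally, the sketch below selects the FlashAttentionMLA parametrization from this file via pytest. It is not part of the PR, and the `-k` substring match against the generated test ids is an assumption.

```python
# Hedged sketch: run just the FlashAttentionMLA full-cudagraph case on a
# Hopper GPU. The -k filter relies on the backend name appearing in the
# pytest parameter id, which is an assumption about id rendering.
import pytest

pytest.main([
    "tests/compile/piecewise/test_full_cudagraph.py",
    "-k", "FlashAttentionMLA",
    "-v",
])
```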
5 changes: 0 additions & 5 deletions tests/v1/attention/test_mla_backends.py
@@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
@@ -87,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors
for each sequence
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
device: Device to create the cache on
@@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
@@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
10 changes: 10 additions & 0 deletions tests/v1/cudagraph/test_cudagraph_mode.py
@@ -62,6 +62,16 @@ class BackendConfig:
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
22 changes: 14 additions & 8 deletions vllm/v1/attention/backends/mla/common.py
@@ -443,11 +443,13 @@ def __init__(self,
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
self.kv_cache_spec = kv_cache_spec
self.device = device
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device

self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
@@ -608,10 +610,12 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks

def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
@@ -624,11 +628,12 @@ def build_for_cudagraph_capture(
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

assert m.max_query_len == 1 # decode-only
assert m.max_query_len <= self.reorder_batch_threshold # decode only

return self.build(0, m)

@@ -819,6 +824,7 @@ def build(self,
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)

attn_metadata = self.metadata_cls(
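Reviewer note: to make the relaxed capture-time check in `build_for_cudagraph_capture` concrete, here is a small self-contained sketch. Names mirror the hunks above, the values are illustrative, and the assertions are copied in spirit from the diff rather than from any other vLLM code path.

```python
# Self-contained sketch (illustrative values, not vLLM code) of the capture-time
# checks after this change: a batch is accepted as long as every request's query
# length stays within reorder_batch_threshold, instead of requiring exactly one
# token per request.
reorder_batch_threshold = 512  # FlashAttnMLAMetadataBuilder's class-level value

def check_capture_batch(query_lens: list[int]) -> None:
    num_reqs = len(query_lens)
    num_actual_tokens = sum(query_lens)
    max_query_len = max(query_lens)

    # Mirrors the relaxed assertions in the diff above.
    assert num_reqs <= num_actual_tokens * reorder_batch_threshold, \
        "MLA only supports decode-only full CUDAGraph capture."
    assert max_query_len <= reorder_batch_threshold  # decode only

check_capture_batch([1, 1, 1, 1])  # classic single-token decode batch
check_capture_batch([4, 4, 4, 4])  # uniform multi-token decode (e.g. spec decode)
```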
77 changes: 71 additions & 6 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -17,11 +17,16 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)

# NOTE(matt): This is an arbitrary number, copied from
# woosuk's implementation in standard FlashAttention backend
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16


class FlashAttnMLABackend(MLACommonBackend):

@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
max_query_len: int
max_seq_len: int
scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0


@dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):

class FlashAttnMLAMetadataBuilder(
MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: ClassVar[int] = 512

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashAttnMLAMetadata)
self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = (get_flash_attn_version() == 3)

self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()

if self.use_full_cuda_graph and self.fa_aot_schedule:
self.max_cudagraph_size = self.compilation_config.max_capture_size

if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
raise ValueError(
"Capture size larger than 992 is not supported for "
"full cuda graph.")

self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.fa_aot_schedule:
@@ -81,14 +114,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
num_splits=self.max_num_splits,
)
return None

def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor
) -> FlashAttnMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
@@ -102,13 +137,37 @@ def _build_decode(
causal=True,
)

# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
# Ensure the persistent buffer is large enough
assert n <= self.scheduler_metadata.shape[0], \
f"Scheduler metadata size {n} exceeds buffer size " + \
f"{self.scheduler_metadata.shape[0]}"
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]

if num_decode_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits

return FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
)


@@ -175,12 +234,17 @@ def _forward_decode(
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

o = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
q_v=q_nope,
max_seqlen_q=attn_metadata.decode.max_query_len,
max_seqlen_q=max_seqlen_q,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
max_seqlen_k=attn_metadata.decode.max_seq_len,
seqused_k=attn_metadata.decode.seq_lens,
@@ -189,6 +253,7 @@ def _forward_decode(
causal=True,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
)

return self._v_up_proj(o)
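Reviewer note: the persistent `scheduler_metadata` handling above follows a standard pattern for full CUDA graph capture. The sketch below shows the idea in isolation with pure PyTorch and made-up sizes; the real buffer lives on the GPU and is sized `max_num_seqs + 1`, so treat this as illustrative only.

```python
# Conceptual sketch of the persistent scheduler-metadata buffer used with full
# CUDA graphs: the captured graph always reads the same pre-allocated tensor,
# so each step copies fresh metadata into its prefix and zeroes the tail to
# keep thread blocks from consuming stale entries.
import torch

persistent = torch.zeros(64, dtype=torch.int32)  # allocated once, reused forever

def refresh(fresh: torch.Tensor) -> torch.Tensor:
    n = fresh.shape[0]
    assert n <= persistent.shape[0], \
        f"Scheduler metadata size {n} exceeds buffer size {persistent.shape[0]}"
    persistent[:n] = fresh   # overwrite the live prefix in place
    persistent[n:] = 0       # zero the rest so no stale metadata is read
    return persistent[:n]    # a view over the same storage the graph captured

view = refresh(torch.arange(5, dtype=torch.int32))
```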
11 changes: 6 additions & 5 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -62,7 +62,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)

self.compilation_config = vllm_config.compilation_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)

@@ -85,10 +84,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
device=self.device,
dtype=torch.int32)

def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens_device,
10 changes: 6 additions & 4 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -104,10 +104,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
dtype=torch.int32,
device=device)

def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
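Reviewer note: a hedged sketch of the `_build_decode` signature now shared across the MLA metadata builders. The class and return type below are simplified stand-ins, not the actual vLLM types; the point is that every override gains `num_decode_tokens`, and backends that do not need it (as in the FlashMLA and ROCm AITER hunks above) simply ignore it.

```python
# Sketch of the updated shared override signature (stand-in types, not vLLM code).
import torch

class DemoMLAMetadataBuilder:
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> dict:
        # Per this PR, only the FlashAttention MLA builder consults
        # num_decode_tokens (comparing it against max_cudagraph_size when
        # choosing max_num_splits); other builders accept it purely to keep
        # the override signature uniform.
        return {
            "block_table": block_table_tensor,
            "seq_lens": seq_lens_device,
            "query_start_loc": query_start_loc_device,
        }
```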