From e9e2517ae72a508e797ff589bbb9f7b56cc6bfc3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 27 Aug 2025 21:01:24 +0000 Subject: [PATCH 1/5] First pass, nonfunctional Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 17 +++-- .../attention/backends/mla/flashattn_mla.py | 70 +++++++++++++++++-- vllm/v1/attention/backends/mla/flashmla.py | 11 +-- .../attention/backends/mla/rocm_aiter_mla.py | 10 +-- 4 files changed, 88 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b4c9aae254ea..c2bf4e5e5f31 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -435,11 +435,13 @@ def __init__(self, self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata self.kv_cache_spec = kv_cache_spec - self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -578,10 +580,12 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -733,6 +737,7 @@ def build(self, seq_lens_device=seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], query_start_loc_device=query_start_loc[:num_decodes + 1], + num_decode_tokens=num_decode_tokens, ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 0e08307ddf84..c5fa64ac8c39 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -17,11 +17,16 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata logger = init_logger(__name__) +# NOTE(matt): This is an arbitrary number, copied from +# woosuk's implementation in standard FlashAttention backend +_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 + class FlashAttnMLABackend(MLACommonBackend): @@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): max_query_len: int max_seq_len: int scheduler_metadata: Optional[torch.Tensor] = None + max_num_splits: int = 0 @dataclass @@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: 
ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + reorder_batch_threshold: ClassVar[int] = 512 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata) + self.max_num_splits = 0 # No upper bound on the number of splits. self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.fa_aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + + if self.max_cudagraph_size > 992: + # This condition derives from FA3's internal heuristic. + # TODO(woosuk): Support larger cudagraph sizes. + raise ValueError( + "Capture size larger than 992 is not supported for " + "full cuda graph.") + + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.fa_aot_schedule: @@ -81,14 +114,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, page_size=self.page_size, cu_seqlens_q=cu_query_lens, causal=causal, + num_splits=self.max_num_splits, ) return None - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor - ) -> FlashAttnMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() @@ -102,6 +137,29 @@ def _build_decode( causal=True, ) + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + # Ensure the persistent buffer is large enough + assert n <= self.scheduler_metadata.shape[0], \ + f"Scheduler metadata size {n} exceeds buffer size " + \ + f"{self.scheduler_metadata.shape[0]}" + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + if num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + return FlashAttnMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -109,6 +167,7 @@ def _build_decode( max_query_len=max_query_len, max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, + max_num_splits=max_num_splits, ) @@ -184,6 +243,7 @@ def _forward_decode( causal=True, fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, + num_splits=attn_metadata.decode.max_num_splits, ) return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index df617ab7a8ea..1555c5b07dac 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -62,7 +62,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) - self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) @@ -85,10 +84,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=self.device, dtype=torch.int32) - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 42670093daa9..2a92a51d47e8 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -104,10 +104,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device) - def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device From 5a062763af9426c5917752345d0da14b4ed0ec92 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 29 Aug 2025 18:36:00 +0000 Subject: [PATCH 2/5] Remove unused parameter Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index e7cd116fdc83..a62993950aff 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache( kv_c_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor], block_size: int, - num_kv_heads: int, head_size: int, dtype: torch.dtype, device: torch.device, @@ -87,7 +86,6 @@ 
def create_and_prepopulate_kv_cache( k_pe_contexts: List of key positional embedding context tensors for each sequence block_size: Size of each block - num_kv_heads: Number of KV heads (should be 1 for MLA) head_size: Size of each head (latent dimension) dtype: Data type for the cache device: Device to create the cache on @@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size @@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_c_contexts=kv_c_contexts, k_pe_contexts=k_pe_contexts, block_size=block_size, - num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, device=device, From f3c1b3ca20ed34bfa61331a61656181fd6763386 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 2 Sep 2025 19:30:24 +0000 Subject: [PATCH 3/5] Handle max_query_len zero during cudagraph capture Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c5fa64ac8c39..d62285d76b99 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -229,12 +229,17 @@ def _forward_decode( kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the + # kernel uses this to calculate grid dimensions. Ensure it's at least 1 + # to prevent invalid grid configuration during graph capture. 
+ max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + o = flash_attn_varlen_func( q=q_pe, k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, - max_seqlen_q=attn_metadata.decode.max_query_len, + max_seqlen_q=max_seqlen_q, cu_seqlens_q=attn_metadata.decode.query_start_loc, max_seqlen_k=attn_metadata.decode.max_seq_len, seqused_k=attn_metadata.decode.seq_lens, From 0e286e31f1d6369b5cd66330d059735aa6e56fdb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 2 Sep 2025 21:09:17 +0000 Subject: [PATCH 4/5] Add FA MLA to CG tests Signed-off-by: Matthew Bonanni --- tests/compile/piecewise/test_full_cudagraph.py | 12 +++++++++++- tests/v1/cudagraph/test_cudagraph_mode.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 97140a9db7af..2454f85342eb 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -61,6 +61,16 @@ class BackendConfig: "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # Cutlass MLA on Blackwell "CutlassMLA": BackendConfig( @@ -102,7 +112,7 @@ class BackendConfig: test_params_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA -MLA_backends = ["FlashMLA", "CutlassMLA"] +MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: test_params_full_cudagraph.append( pytest.param( diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 81655e417500..25e01806f495 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -62,6 +62,16 @@ class BackendConfig: "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # FA2 "FA2": BackendConfig(name="FA2", From 235670f424885751ba47080b173e4e0c33444a58 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 8 Sep 2025 17:37:05 +0000 Subject: [PATCH 5/5] Update cudagraph build assertions Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index c0019d1c70da..226bc436058d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -628,11 +628,12 @@ def build_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ + assert m.num_reqs <= (m.num_actual_tokens * + self.reorder_batch_threshold), \ "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - assert m.max_query_len == 1 # decode-only + assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m)
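Note on the full-CUDA-graph path added in patch 1: FA3's scheduler metadata must live at a fixed address across graph replays, so the builder allocates a persistent int32 buffer of max_num_seqs + 1 entries, copies the freshly computed metadata into its head, and zeroes the tail so replayed thread blocks never consume stale entries from an earlier, larger batch. Below is a minimal standalone sketch of that pattern; the class and method names are invented for illustration and are not the vLLM implementation.

import torch

class PersistentDecodeBuffers:
    """Buffers whose addresses must stay fixed across CUDA graph replays."""

    def __init__(self, max_num_seqs: int, device: torch.device):
        # Mirrors the patch's allocation of max_num_seqs + 1 int32 entries
        # for the FA3 scheduler metadata.
        self.scheduler_metadata = torch.zeros(max_num_seqs + 1,
                                              dtype=torch.int32,
                                              device=device)

    def fill_scheduler_metadata(self, fresh: torch.Tensor) -> torch.Tensor:
        """Copy fresh metadata into the persistent buffer and zero the tail."""
        n = fresh.shape[0]
        assert n <= self.scheduler_metadata.shape[0], (
            f"Scheduler metadata size {n} exceeds buffer size "
            f"{self.scheduler_metadata.shape[0]}")
        self.scheduler_metadata[:n] = fresh
        # Zero the rest so replayed thread blocks never read metadata left
        # over from a previous batch and overwrite the output buffer.
        self.scheduler_metadata[n:] = 0
        return self.scheduler_metadata[:n]

if __name__ == "__main__":
    bufs = PersistentDecodeBuffers(max_num_seqs=8, device=torch.device("cpu"))
    view = bufs.fill_scheduler_metadata(torch.arange(1, 5, dtype=torch.int32))
    print(view)                     # tensor([1, 2, 3, 4], dtype=torch.int32)
    print(bufs.scheduler_metadata)  # tail remains zeroed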
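Patches 1 and 3 also bound two kernel parameters when a captured graph may replay the decode: num_splits > 1 allocates intermediate buffers of shape [num_splits, num_heads, num_tokens, head_size], so the cap (an arbitrary default of 16) is applied only when num_decode_tokens fits within the largest capture size, and max_seqlen_q is clamped to at least 1 because the dummy capture batch can report max_query_len == 0 while the kernel derives its grid dimensions from it. The helper below is a sketch under those assumptions, not the builder's actual code.

from dataclasses import dataclass

MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16  # mirrors _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

@dataclass
class DecodeCallArgs:
    max_seqlen_q: int
    num_splits: int

def build_decode_call_args(max_query_len: int, num_decode_tokens: int,
                           use_full_cuda_graph: bool,
                           max_cudagraph_size: int) -> DecodeCallArgs:
    # 0 lets the kernel choose the split count freely; the cap matters only
    # when a captured graph must pre-allocate the intermediate
    # [num_splits, num_heads, num_tokens, head_size] buffers.
    num_splits = 0
    if use_full_cuda_graph and num_decode_tokens <= max_cudagraph_size:
        num_splits = MAX_NUM_SPLITS_FOR_CUDA_GRAPH
    # The dummy capture batch can report max_query_len == 0, but the kernel
    # uses it to size its grid, so clamp to at least 1.
    return DecodeCallArgs(max_seqlen_q=max(max_query_len, 1),
                          num_splits=num_splits)

if __name__ == "__main__":
    print(build_decode_call_args(0, 64, True, 512))    # capture-sized batch
    print(build_decode_call_args(1, 2048, True, 512))  # larger batch: no cap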
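Patch 5 relaxes the decode-only capture assertions now that reorder_batch_threshold is 512, i.e. a "decode" request may carry up to 512 query tokens (e.g. speculative decoding) instead of exactly one. The intent is that every request in the captured batch stays within that per-request bound; the check below states that invariant directly and is only a sketch, not the patch's literal assertion.

def check_decode_only_batch(num_reqs: int, num_actual_tokens: int,
                            max_query_len: int,
                            reorder_batch_threshold: int = 512) -> None:
    """Raise if the batch is not decode-only for full-cudagraph capture."""
    assert num_actual_tokens <= num_reqs * reorder_batch_threshold, (
        "MLA only supports decode-only full CUDAGraph capture; the batch "
        "carries more tokens than a decode-only batch allows")
    assert max_query_len <= reorder_batch_threshold, (
        "a request exceeds the decode query-length threshold")

check_decode_only_batch(num_reqs=4, num_actual_tokens=4, max_query_len=1)   # pure decode
check_decode_only_batch(num_reqs=2, num_actual_tokens=10, max_query_len=5)  # spec decode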
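For context, patch 4 only wires the backend into the existing cuda-graph tests via VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA with cudagraph_mode "FULL_DECODE_ONLY" on Hopper, using deepseek-ai/DeepSeek-V2-Lite as the MLA model. Outside the test suite, a roughly equivalent manual run might look like the following; the LLM(...) call is an assumption about the offline API rather than part of this series, and it requires a Hopper GPU plus a build containing these patches.

import os

# Select the attention backend before vLLM is initialized.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          trust_remote_code=True,
          compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"})
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)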