
Commit 1b2ea7a

MatthewBonanni authored and robertgshaw2-redhat committed
[Attention] FlashAttention MLA cudagraph support (vllm-project#23958)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
1 parent ff4b305 · commit 1b2ea7a

7 files changed: +118 −29 lines

tests/compile/piecewise/test_full_cudagraph.py

Lines changed: 11 additions & 1 deletion
@@ -61,6 +61,16 @@ class BackendConfig:
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   },
                   specific_gpu_arch=(9, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
     # Cutlass MLA on Blackwell
     "CutlassMLA":
     BackendConfig(
@@ -102,7 +112,7 @@ class BackendConfig:
 test_params_full_cudagraph = []
 
 # deepseek-ai/DeepSeek-V2-Lite with MLA
-MLA_backends = ["FlashMLA", "CutlassMLA"]
+MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
 for mla_backend in MLA_backends:
     test_params_full_cudagraph.append(
         pytest.param(
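The new backend entry above doubles as a usage recipe: FlashAttention MLA is selected through the `VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA` environment variable and captures decode-only full CUDA graphs via `cudagraph_mode: FULL_DECODE_ONLY` on Hopper (SM 9.0). A minimal end-to-end sketch, assuming a Hopper GPU, a vLLM build that includes this commit, and that `LLM(..., compilation_config=...)` accepts a dict as in recent releases; the model and sampling settings are illustrative:

```python
import os

# Select the FlashAttention MLA backend, mirroring the test's env_vars.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"

from vllm import LLM, SamplingParams

# deepseek-ai/DeepSeek-V2-Lite is the MLA model exercised by these tests.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```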

tests/v1/attention/test_mla_backends.py

Lines changed: 0 additions & 5 deletions
@@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache(
         kv_c_contexts: list[torch.Tensor],
         k_pe_contexts: list[torch.Tensor],
         block_size: int,
-        num_kv_heads: int,
         head_size: int,
         dtype: torch.dtype,
         device: torch.device,
@@ -87,7 +86,6 @@
         k_pe_contexts: List of key positional embedding context tensors
             for each sequence
         block_size: Size of each block
-        num_kv_heads: Number of KV heads (should be 1 for MLA)
         head_size: Size of each head (latent dimension)
         dtype: Data type for the cache
         device: Device to create the cache on
@@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     query_lens = batch_spec.query_lens
     num_q_heads = vllm_config.model_config.get_num_attention_heads(
         vllm_config.parallel_config)
-    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-        vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
@@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         kv_c_contexts=kv_c_contexts,
         k_pe_contexts=k_pe_contexts,
         block_size=block_size,
-        num_kv_heads=num_kv_heads,
         head_size=head_size,
         dtype=dtype,
         device=device,

tests/v1/cudagraph/test_cudagraph_mode.py

Lines changed: 10 additions & 0 deletions
@@ -62,6 +62,16 @@ class BackendConfig:
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   },
                   specific_gpu_arch=(9, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
     # FA2
     "FA2":
     BackendConfig(name="FA2",

vllm/v1/attention/backends/mla/common.py

Lines changed: 14 additions & 8 deletions
@@ -443,11 +443,13 @@ def __init__(self,
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
         self.kv_cache_spec = kv_cache_spec
-        self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         parallel_config = vllm_config.parallel_config
+        cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
+        self.device = device
+
         self.num_heads = self.model_config.get_num_attention_heads(
             parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
@@ -608,10 +610,12 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         prefill.prefill_main = self._fi_prefill_main
         prefill.prefill_chunks = self._fi_prefill_chunks
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> MLACommonDecodeMetadata:
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
@@ -624,11 +628,12 @@ def build_for_cudagraph_capture(
         Currently, only decode is supported for full cudagraphs with MLA.
         """
         m = common_attn_metadata
-        assert m.num_reqs == m.num_actual_tokens, \
+        assert m.num_reqs <= (m.num_actual_tokens *
+                              self.reorder_batch_threshold), \
             "MLA only supports decode-only full CUDAGraph capture. " \
             "Make sure all cudagraph capture sizes <= max_num_seq."
 
-        assert m.max_query_len == 1  # decode-only
+        assert m.max_query_len <= self.reorder_batch_threshold  # decode only
 
         return self.build(0, m)
 
@@ -819,6 +824,7 @@ def build(self,
             seq_lens_device=seq_lens[:num_decodes],
             query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
             query_start_loc_device=query_start_loc[:num_decodes + 1],
+            num_decode_tokens=num_decode_tokens,
         )
 
         attn_metadata = self.metadata_cls(
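The common builder now threads the number of decode tokens into the `_build_decode` hook, which is why every MLA backend below updates its override. A hypothetical minimal subclass, shown only to illustrate the new signature (the class name and the delegating body are illustrative, not part of this commit):

```python
import torch

from vllm.v1.attention.backends.mla.common import (MLACommonDecodeMetadata,
                                                    MLACommonMetadataBuilder)


class ExampleMLAMetadataBuilder(MLACommonMetadataBuilder):
    """Illustrative override of the updated _build_decode hook."""

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> MLACommonDecodeMetadata:
        # num_decode_tokens is the newly added argument; FlashAttention MLA
        # compares it against the max cudagraph capture size to decide whether
        # to cap the number of FA3 splits (see flashattn_mla.py below).
        return super()._build_decode(block_table_tensor, seq_lens_cpu,
                                     seq_lens_device, query_start_loc_cpu,
                                     query_start_loc_device,
                                     num_decode_tokens)
```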

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 71 additions & 6 deletions
@@ -17,11 +17,16 @@
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
 
 logger = init_logger(__name__)
 
+# NOTE(matt): This is an arbitrary number, copied from
+# woosuk's implementation in standard FlashAttention backend
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
     max_query_len: int
     max_seq_len: int
     scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
 
 @dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
 
 class FlashAttnMLAMetadataBuilder(
         MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+
     reorder_batch_threshold: ClassVar[int] = 512
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashAttnMLAMetadata)
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = (get_flash_attn_version() == 3)
 
+        self.use_full_cuda_graph = \
+            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+
+        if self.use_full_cuda_graph and self.fa_aot_schedule:
+            self.max_cudagraph_size = self.compilation_config.max_capture_size
+
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                vllm_config.scheduler_config.max_num_seqs + 1,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -81,14 +114,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                num_splits=self.max_num_splits,
             )
         return None
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor
-    ) -> FlashAttnMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_cpu.max().item()
@@ -102,13 +137,37 @@ def _build_decode(
             causal=True,
         )
 
+        # For FA3 + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and scheduler_metadata is not None:
+            n = scheduler_metadata.shape[0]
+            # Ensure the persistent buffer is large enough
+            assert n <= self.scheduler_metadata.shape[0], \
+                f"Scheduler metadata size {n} exceeds buffer size " + \
+                f"{self.scheduler_metadata.shape[0]}"
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+            if num_decode_tokens <= self.max_cudagraph_size:
+                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+                # usage, because the intermediate buffers of size [num_splits,
+                # num_heads, num_tokens, head_size] are allocated. Therefore,
+                # we only set num_splits when using cuda graphs.
+                max_num_splits = self.max_num_splits
+
         return FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
             query_start_loc=query_start_loc_device,
             max_query_len=max_query_len,
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
 
 
@@ -175,12 +234,17 @@ def _forward_decode(
         kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
         k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
 
+        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
+        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
+        # to prevent invalid grid configuration during graph capture.
+        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
+
         o = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
             q_v=q_nope,
-            max_seqlen_q=attn_metadata.decode.max_query_len,
+            max_seqlen_q=max_seqlen_q,
             cu_seqlens_q=attn_metadata.decode.query_start_loc,
             max_seqlen_k=attn_metadata.decode.max_seq_len,
             seqused_k=attn_metadata.decode.seq_lens,
@@ -189,6 +253,7 @@ def _forward_decode(
             causal=True,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
+            num_splits=attn_metadata.decode.max_num_splits,
         )
 
         return self._v_up_proj(o)
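Two mechanics in this file carry most of the CUDA graph support: the FA3 AOT scheduler metadata is copied into a pre-allocated persistent buffer whose tail is zeroed, so replayed graphs always read valid data from a stable address, and `num_splits` is capped at a fixed bound so the split intermediate buffers allocated at capture time are large enough for later replays. A standalone sketch of the persistent-buffer pattern in plain PyTorch, with illustrative names rather than vLLM APIs:

```python
import torch

MAX_NUM_SEQS = 8  # stands in for scheduler_config.max_num_seqs

# Allocated once, so its storage address stays stable across graph replays.
persistent = torch.zeros(MAX_NUM_SEQS + 1, dtype=torch.int32)


def update_scheduler_metadata(new_metadata: torch.Tensor) -> torch.Tensor:
    """Copy freshly computed metadata into the persistent buffer."""
    n = new_metadata.shape[0]
    assert n <= persistent.shape[0], "metadata exceeds pre-allocated buffer"
    persistent[:n] = new_metadata
    # Zero the tail so stale entries from a previous, larger batch cannot be
    # read by thread blocks during graph replay.
    persistent[n:] = 0
    return persistent[:n]


# A smaller batch after a larger one reuses the same storage safely.
print(update_scheduler_metadata(torch.arange(5, dtype=torch.int32)))
print(update_scheduler_metadata(torch.arange(3, dtype=torch.int32)))
```

The same reasoning explains the `max(max_query_len, 1)` clamp in `_forward_decode`: during capture the batch can contain zero real query tokens, and the kernel still needs a valid grid configuration.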

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 6 additions & 5 deletions
@@ -62,7 +62,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashMLAMetadata)
 
-        self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config)
 
@@ -85,10 +84,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             device=self.device,
             dtype=torch.int32)
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens_device,

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 6 additions & 4 deletions
@@ -104,10 +104,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             dtype=torch.int32,
             device=device)
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens_device + page_size - 1) // page_size
         device = self.device
