Commit 5beacce

Daisy-Ma-coder and qqma authored
[BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 (#27128)
Signed-off-by: qqma <qqma@amazon.com>
Co-authored-by: qqma <qqma@amazon.com>
1 parent 8669c69 commit 5beacce

File tree

1 file changed: +20 -13 lines changed


vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 20 additions & 13 deletions
@@ -89,10 +89,9 @@ def __init__(
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
 
         if self.use_full_cuda_graph and self.fa_aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
-
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
@@ -114,7 +113,14 @@ def __init__(
             self.max_num_splits = 1
 
     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
+        self,
+        num_reqs,
+        cu_query_lens,
+        max_query_len,
+        seqlens,
+        max_seq_len,
+        causal,
+        max_num_splits,
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -130,7 +136,7 @@ def _schedule_decode(
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
-                num_splits=self.max_num_splits,
+                num_splits=max_num_splits,
             )
         return None
 
@@ -148,17 +154,25 @@ def _build_decode(
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()
 
+        # For Flash Attention MLA + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
             max_query_len=max_query_len,
             seqlens=seq_lens_device,
             max_seq_len=max_seq_len,
             causal=True,
+            max_num_splits=max_num_splits,
         )
 
-        # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             # Ensure the persistent buffer is large enough
@@ -174,13 +188,6 @@ def _build_decode(
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-            if num_decode_tokens <= self.max_cudagraph_size:
-                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-                # usage, because the intermediate buffers of size [num_splits,
-                # num_heads, num_tokens, head_size] are allocated. Therefore,
-                # we only set num_splits when using cuda graphs.
-                max_num_splits = self.max_num_splits
-
         if vllm_is_batch_invariant():
             max_num_splits = 1
 

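For context, a minimal sketch of the control flow this fix establishes (hypothetical helper and argument names, not the actual vLLM code): the split bound is now decided before the AOT scheduler metadata is built and is passed explicitly into _schedule_decode, so the schedule captured for the full CUDA graph uses the same num_splits as the buffers it will be replayed with, instead of being scheduled with self.max_num_splits and only adjusted afterwards.

# Sketch only, under assumptions: "builder" stands in for the metadata
# builder changed in this diff, and the remaining keyword arguments are
# simply forwarded; the names here are illustrative, not vLLM's API.
def build_decode_sketch(builder, num_decode_tokens, **decode_kwargs):
    # Decide the split bound up front; 0 means no upper bound.
    max_num_splits = 0
    if builder.use_full_cuda_graph and num_decode_tokens <= builder.max_cudagraph_size:
        # num_splits > 1 allocates intermediate buffers shaped
        # [num_splits, num_heads, num_tokens, head_size], so the bound
        # is only applied when capturing cuda graphs.
        max_num_splits = builder.max_num_splits

    # The AOT schedule is computed with the same bound the captured
    # graph replays with, which is what avoids the illegal memory access.
    scheduler_metadata = builder._schedule_decode(
        causal=True,
        max_num_splits=max_num_splits,
        **decode_kwargs,
    )
    return scheduler_metadata, max_num_splits
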
0 commit comments
