@@ -89,10 +89,9 @@ def __init__(
8989 self .use_full_cuda_graph = (
9090 self .compilation_config .cudagraph_mode .has_full_cudagraphs ()
9191 )
92+ self .max_cudagraph_size = self .compilation_config .max_capture_size
9293
9394 if self .use_full_cuda_graph and self .fa_aot_schedule :
94- self .max_cudagraph_size = self .compilation_config .max_capture_size
95-
9695 if self .max_cudagraph_size > 992 :
9796 # This condition derives from FA3's internal heuristic.
9897 # TODO(woosuk): Support larger cudagraph sizes.
@@ -114,7 +113,7 @@ def __init__(
114113 self .max_num_splits = 1
115114
116115 def _schedule_decode (
117- self , num_reqs , cu_query_lens , max_query_len , seqlens , max_seq_len , causal
116+ self , num_reqs , cu_query_lens , max_query_len , seqlens , max_seq_len , causal , max_num_splits
118117 ):
119118 if self .fa_aot_schedule :
120119 return get_scheduler_metadata (
@@ -130,7 +129,7 @@ def _schedule_decode(
130129 page_size = self .page_size ,
131130 cu_seqlens_q = cu_query_lens ,
132131 causal = causal ,
133- num_splits = self . max_num_splits ,
132+ num_splits = max_num_splits ,
134133 )
135134 return None
136135
@@ -148,17 +147,27 @@ def _build_decode(
148147 max_query_len = query_lens_cpu .max ().item ()
149148 max_seq_len = seq_lens_device .max ().item ()
150149
150+
151+ # For Flash Attention MLA + full cudagraph
152+ max_num_splits = 0
153+ if self .use_full_cuda_graph and \
154+ num_decode_tokens <= self .max_cudagraph_size :
155+ # NOTE(woosuk): Setting num_splits > 1 may increase the memory
156+ # usage, because the intermediate buffers of size [num_splits,
157+ # num_heads, num_tokens, head_size] are allocated. Therefore,
158+ # we only set num_splits when using cuda graphs.
159+ max_num_splits = self .max_num_splits
160+
151161 scheduler_metadata = self ._schedule_decode (
152162 num_reqs = seq_lens_cpu .numel (),
153163 cu_query_lens = query_start_loc_device ,
154164 max_query_len = max_query_len ,
155165 seqlens = seq_lens_device ,
156166 max_seq_len = max_seq_len ,
157167 causal = True ,
168+ max_num_splits = max_num_splits ,
158169 )
159170
160- # For FA3 + full cudagraph
161- max_num_splits = 0
162171 if self .use_full_cuda_graph and scheduler_metadata is not None :
163172 n = scheduler_metadata .shape [0 ]
164173 # Ensure the persistent buffer is large enough
@@ -174,13 +183,6 @@ def _build_decode(
174183 self .scheduler_metadata [n :] = 0
175184 scheduler_metadata = self .scheduler_metadata [:n ]
176185
177- if num_decode_tokens <= self .max_cudagraph_size :
178- # NOTE(woosuk): Setting num_splits > 1 may increase the memory
179- # usage, because the intermediate buffers of size [num_splits,
180- # num_heads, num_tokens, head_size] are allocated. Therefore,
181- # we only set num_splits when using cuda graphs.
182- max_num_splits = self .max_num_splits
183-
184186 if vllm_is_batch_invariant ():
185187 max_num_splits = 1
186188
0 commit comments