@@ -158,12 +158,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
158158
159159 self .aot_schedule = (get_flash_attn_version () == 3 )
160160 self .use_full_cuda_graph = compilation_config .full_cuda_graph
161- if self .use_full_cuda_graph and not self .aot_schedule :
162- raise ValueError ("Full CUDA graph mode requires AOT scheduling, "
163- "which requires FlashAttention 3." )
164- self .scheduler_metadata = torch .zeros (self .runner .max_num_reqs + 1 ,
165- dtype = torch .int32 ,
166- device = self .runner .device )
161+ if self .use_full_cuda_graph :
162+ # NOTE(lucas): AOT scheduling is not yet supported in full CUDA graph
163+ # mode. The scheduler and the kernel must always use the same
164+ # num_splits (which acts as an upper bound with the dynamic split
165+ # scheduler), but that value is currently decided heuristically by
166+ # the kernel launching code.
167+ self .aot_schedule = False
167168
168169 # Sliding window size to be used with the AOT scheduler will be
169170 # populated on first build() call.
@@ -299,18 +300,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
299300 max_seq_len = max_seq_len ,
300301 causal = True )
301302
302- if self .use_full_cuda_graph :
303- assert scheduler_metadata is not None
304- n = scheduler_metadata .shape [0 ]
305- self .scheduler_metadata [:n ].copy_ (scheduler_metadata ,
306- non_blocking = True )
307- # NOTE(woosuk): We should zero out the rest of the scheduler
308- # metadata to guarantee the correctness. Otherwise, some thread
309- # blocks may use the invalid scheduler metadata and overwrite the
310- # output buffer.
311- self .scheduler_metadata [n :] = 0
312- scheduler_metadata = self .scheduler_metadata [:n ]
313-
314303 attn_metadata = FlashAttentionMetadata (
315304 num_actual_tokens = num_actual_tokens ,
316305 max_query_len = max_query_len ,
0 commit comments