@@ -3477,8 +3477,10 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
         # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
-            # cudagraph and for uniform decode batches.
-            capture_ubatched_graph = self.parallel_config.enable_dbo \
+            # cudagraph, a uniform decode batch, and the number of tokens
+            # is above the threshold. Otherwise we just capture a non-ubatched
+            # version of the graph.
+            allow_microbatching = self.parallel_config.enable_dbo \
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL \
                 and uniform_decode \
                 and check_ubatch_thresholds(
@@ -3487,37 +3489,27 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                     uniform_decode=uniform_decode,
                 )
 
-            # Currently we capture both microbatched and non-microbatched
-            # graphs when capture_ubatched_graph is True, this is because
-            # occasionally we will be forced out of microbatching due to other
-            # DP ranks not microbatching (usually caused by an empty second
-            # microbatch; once we resolve this, we can remove the
-            # non-microbatched graph capture).
-            allow_microbatching_options = [True, False] if \
-                capture_ubatched_graph else [False]
-            for allow_microbatching in allow_microbatching_options:
-                for _ in range(
-                        self.compilation_config.cudagraph_num_of_warmups):
-                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                    # But be careful, warm up with `NONE` is orthogonal to
-                    # if we want to warm up attention or not. This is
-                    # different from the case where `FULL` implies capture
-                    # attention while `PIECEWISE` implies no attention.
-                    force_attention = (
-                        cudagraph_runtime_mode == CUDAGraphMode.FULL)
-                    self._dummy_run(num_tokens,
-                                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                                    force_attention=force_attention,
-                                    uniform_decode=uniform_decode,
-                                    allow_microbatching=allow_microbatching,
-                                    skip_eplb=True,
-                                    remove_lora=False)
+            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                # But be careful, warm up with `NONE` is orthogonal to
+                # if we want to warm up attention or not. This is
+                # different from the case where `FULL` implies capture
+                # attention while `PIECEWISE` implies no attention.
+                force_attention = (
+                    cudagraph_runtime_mode == CUDAGraphMode.FULL)
                 self._dummy_run(num_tokens,
-                                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                                force_attention=force_attention,
                                 uniform_decode=uniform_decode,
                                 allow_microbatching=allow_microbatching,
                                 skip_eplb=True,
                                 remove_lora=False)
+            self._dummy_run(num_tokens,
+                            cudagraph_runtime_mode=cudagraph_runtime_mode,
+                            uniform_decode=uniform_decode,
+                            allow_microbatching=allow_microbatching,
+                            skip_eplb=True,
+                            remove_lora=False)
         self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
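
For context on the flow above: each compilation case is first run eagerly (CUDAGraphMode.NONE) so that lazy kernel initialization and autotuning happen outside graph capture, and only then is the graph captured once with the target runtime mode. The sketch below illustrates that warmup-then-capture pattern with raw torch.cuda.CUDAGraph; NUM_WARMUPS, BATCH_SIZES, and the toy model are hypothetical stand-ins for cudagraph_num_of_warmups, compilation_cases, and the real forward pass, not vLLM's actual API.

# Minimal sketch of the warmup-then-capture pattern, assuming a
# CUDA-capable PyTorch install. All names below are illustrative
# stand-ins, not vLLM identifiers.
import torch

NUM_WARMUPS = 2          # stand-in for cudagraph_num_of_warmups
BATCH_SIZES = [4, 2, 1]  # stand-in for compilation_cases

model = torch.nn.Linear(16, 16).cuda()
graphs: dict[int, torch.cuda.CUDAGraph] = {}
static_inputs: dict[int, torch.Tensor] = {}
static_outputs: dict[int, torch.Tensor] = {}

for num_tokens in BATCH_SIZES:
    x = torch.zeros(num_tokens, 16, device="cuda")
    # Warm up eagerly first (the CUDAGraphMode.NONE dummy runs): lazy
    # init and autotuning must not happen during capture.
    for _ in range(NUM_WARMUPS):
        model(x)
    torch.cuda.synchronize()
    # Then capture once for this batch size (the final _dummy_run with
    # the target cudagraph_runtime_mode).
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_outputs[num_tokens] = model(x)
    graphs[num_tokens] = g
    static_inputs[num_tokens] = x

# Replay: overwrite the captured input buffer in place, then replay.
static_inputs[2].copy_(torch.randn(2, 16, device="cuda"))
graphs[2].replay()
result = static_outputs[2]  # refreshed by the replayed graph

Replay reuses the buffers captured for a given batch size, which is why the capture loop above sizes its dummy runs per num_tokens.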