@@ -3075,13 +3075,19 @@ def _dummy_run(
30753075 # We currently only microbatch if the number of tokens is
30763076 # over a certain threshold.
30773077 if self .parallel_config .enable_dbo and allow_microbatching :
3078- ubatch_slices , num_tokens_after_padding = ubatch_split (
3078+ ubatch_slices , ubatch_num_tokens_after_padding = ubatch_split (
30793079 num_scheduled_tokens ,
30803080 total_num_scheduled_tokens ,
30813081 total_num_scheduled_tokens ,
30823082 uniform_decode = uniform_decode ,
30833083 vllm_config = self .vllm_config ,
30843084 )
3085+ # Currently when DBO is enabled `ubatch_split` returns
3086+ # the num_tokens_after_padding for a single ubatch, but we have 2
3087+ # TODO(sage,lucas): this is cruft that should be addressed in the
3088+ # padding refactor.
3089+ if ubatch_num_tokens_after_padding is not None :
3090+ num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
30853091
30863092 # If we failed to microbatch, currently need to resynchronize
30873093 # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3198,7 +3204,7 @@ def _dummy_run(
31983204
31993205 # filter out the valid batch descriptor
32003206 _cg_mode , batch_descriptor = self .cudagraph_dispatcher .dispatch (
3201- BatchDescriptor (num_tokens = num_tokens ,
3207+ BatchDescriptor (num_tokens = num_tokens_after_padding ,
32023208 uniform_decode = uniform_decode )) \
32033209 if not is_profile else (CUDAGraphMode .NONE , None )
32043210 if cudagraph_runtime_mode is not None :
@@ -3212,7 +3218,13 @@ def _dummy_run(
32123218 cudagraph_runtime_mode = _cg_mode
32133219
32143220 if ubatch_slices is not None :
3215- num_tokens = num_tokens // 2
3221+ # Adjust values to reflect a single ubatch.
3222+ # TODO(sage,lucas): this is cruft that should be addressed in
3223+ # the padding refactor.
3224+ num_tokens_after_padding = ubatch_slices [0 ].num_tokens
3225+ if num_tokens_across_dp is not None :
3226+ num_tokens_across_dp [:] = num_tokens_after_padding
3227+
32163228 with self .maybe_randomize_inputs (input_ids ), set_forward_context (
32173229 attn_metadata ,
32183230 self .vllm_config ,
0 commit comments