diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f8b0b9cba1bc..4425d2731289 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3066,13 +3066,19 @@ def _dummy_run(
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
         if self.parallel_config.enable_dbo and allow_microbatching:
-            ubatch_slices, num_tokens_after_padding = ubatch_split(
+            ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
                 num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 uniform_decode=uniform_decode,
                 vllm_config=self.vllm_config,
             )
+            # Currently when DBO is enabled `ubatch_split` returns
+            # the num_tokens_after_padding for a single ubatch, but we have 2
+            # TODO(sage,lucas): this is cruft that should be addressed in the
+            # padding refactor.
+            if ubatch_num_tokens_after_padding is not None:
+                num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
 
         # If we failed to microbatch, currently need to resynchronize
         # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3189,7 +3195,7 @@ def _dummy_run(
 
         # filter out the valid batch descriptor
         _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-            BatchDescriptor(num_tokens=num_tokens,
+            BatchDescriptor(num_tokens=num_tokens_after_padding,
                             uniform_decode=uniform_decode)) \
             if not is_profile else (CUDAGraphMode.NONE, None)
         if cudagraph_runtime_mode is not None:
@@ -3203,7 +3209,13 @@ def _dummy_run(
             cudagraph_runtime_mode = _cg_mode
 
         if ubatch_slices is not None:
-            num_tokens = num_tokens // 2
+            # Adjust values to reflect a single ubatch.
+            # TODO(sage,lucas): this is cruft that should be addressed in
+            # the padding refactor.
+            num_tokens_after_padding = ubatch_slices[0].num_tokens
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp[:] = num_tokens_after_padding
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
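
For context, here is a minimal standalone sketch of the padding bookkeeping the hunks above implement. It is not the actual vLLM code: the UbatchSlice dataclass, the resolve_padded_token_counts helper, and the tensor handling are assumptions for illustration. The idea is that ubatch_split reports the padded token count for a single microbatch, so the value used for cudagraph dispatch is that count doubled, while the per-ubatch count is what gets written into the per-DP-rank token tensor before execution.

# Hypothetical sketch only; names below are illustrative, not vLLM's real API.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class UbatchSlice:
    num_tokens: int  # padded token count for one microbatch


def resolve_padded_token_counts(
    ubatch_slices: Optional[list[UbatchSlice]],
    ubatch_num_tokens_after_padding: Optional[int],
    num_tokens_after_padding: int,
    num_tokens_across_dp: Optional[torch.Tensor],
) -> int:
    # ubatch_split reports padding for a single ubatch, but a DBO step runs
    # two ubatches, so the value used to pick a cudagraph is the doubled one.
    if ubatch_num_tokens_after_padding is not None:
        num_tokens_after_padding = ubatch_num_tokens_after_padding * 2

    # ... cudagraph dispatch would key off num_tokens_after_padding here ...

    # Each ubatch then executes with its own (single-ubatch) padded size,
    # and that per-ubatch size is what every DP rank must agree on.
    if ubatch_slices is not None:
        num_tokens_after_padding = ubatch_slices[0].num_tokens
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[:] = num_tokens_after_padding

    return num_tokens_after_padding


if __name__ == "__main__":
    slices = [UbatchSlice(num_tokens=64), UbatchSlice(num_tokens=64)]
    across_dp = torch.zeros(2, dtype=torch.int64)
    per_ubatch = resolve_padded_token_counts(slices, 64, 128, across_dp)
    print(per_ubatch, across_dp.tolist())  # 64 [64, 64]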