18 changes: 15 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -3066,13 +3066,19 @@ def _dummy_run(
         # We currently only microbatch if the number of tokens is
         # over a certain threshold.
         if self.parallel_config.enable_dbo and allow_microbatching:
-            ubatch_slices, num_tokens_after_padding = ubatch_split(
+            ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
                 num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 uniform_decode=uniform_decode,
                 vllm_config=self.vllm_config,
             )
+            # Currently when DBO is enabled `ubatch_split` returns
+            # the num_tokens_after_padding for a single ubatch, but we have 2
+            # TODO(sage,lucas): this is cruft that should be addressed in the
+            # padding refactor.
+            if ubatch_num_tokens_after_padding is not None:
+                num_tokens_after_padding = ubatch_num_tokens_after_padding * 2

         # If we failed to microbatch, currently need to resynchronize
         # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3189,7 +3195,7 @@ def _dummy_run(

         # filter out the valid batch descriptor
         _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-            BatchDescriptor(num_tokens=num_tokens,
Contributor (critical):

Using num_tokens for the BatchDescriptor is incorrect in a distributed setting as it represents the local, unpadded number of tokens, which can differ across data-parallel ranks. This discrepancy can lead to ranks dispatching to different CUDA graphs (or none at all), causing a hang during subsequent collective operations. The change to use num_tokens_after_padding is correct, as this value is synchronized across all ranks, ensuring consistent CUDA graph dispatching and preventing hangs.

+            BatchDescriptor(num_tokens=num_tokens_after_padding,

Collaborator:

Good catch!
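
To make the concern above concrete, here is a minimal, self-contained sketch of the failure mode (illustrative only: the simplified BatchDescriptor, the dispatch_key helper, and the token counts are assumptions, not vLLM's actual dispatcher code). A dispatch key built from the local, unpadded token count can differ across data-parallel ranks, while a key built from the synchronized padded count cannot:

from dataclasses import dataclass


@dataclass(frozen=True)
class BatchDescriptor:
    # Simplified stand-in for the descriptor used to look up a captured
    # CUDA graph; every rank must build exactly the same key.
    num_tokens: int
    uniform_decode: bool


def dispatch_key(local_tokens: int, padded_tokens: int,
                 use_padded: bool) -> BatchDescriptor:
    n = padded_tokens if use_padded else local_tokens
    return BatchDescriptor(num_tokens=n, uniform_decode=True)


# Two DP ranks: rank 0 has 7 local tokens, rank 1 has 12; both pad to 16.
keys_local = {dispatch_key(n, 16, use_padded=False) for n in (7, 12)}
keys_padded = {dispatch_key(n, 16, use_padded=True) for n in (7, 12)}

assert len(keys_local) == 2   # ranks disagree -> different (or no) graph -> hang
assert len(keys_padded) == 1  # ranks agree -> consistent dispatch

The patch applies the same idea: num_tokens_after_padding is already synchronized across ranks, so every rank constructs the same descriptor.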

                             uniform_decode=uniform_decode)) \
             if not is_profile else (CUDAGraphMode.NONE, None)
         if cudagraph_runtime_mode is not None:
@@ -3203,7 +3209,13 @@ def _dummy_run(
             cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
-            num_tokens = num_tokens // 2
+            # Adjust values to reflect a single ubatch.
+            # TODO(sage,lucas): this is cruft that should be addressed in
+            # the padding refactor.
+            num_tokens_after_padding = ubatch_slices[0].num_tokens
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp[:] = num_tokens_after_padding
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
             attn_metadata,
             self.vllm_config,