Conversation

@varun-sundar-rabindranath (Contributor) commented Oct 4, 2025

Purpose

Fixes #26137 ([Bug]: Assertion in AgRsAll2AllManager in DP+EP)

PR #24845 removed the padding. This PR re-introduces it.

Test Plan

server:
ALL2ALL Backend Naive:

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 2 --max-num-seqs 256 --enable-expert-parallel --load-format dummy --gpu-memory-utilization 0.85

DBO:

NCCL_P2P_DISABLE=1 VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --max-num-seqs 512 --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --disable-log-requests --enable-dbo --dbo-decode-token-threshold 4

DBO + small cudagraph size:

NCCL_P2P_DISABLE=1 VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --max-num-seqs 512 --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --disable-log-requests --enable-dbo --dbo-decode-token-threshold 4 --cuda-graph-sizes 4

client:

vllm bench serve --port 8000 --model ${MODEL} --dataset-name random --num-prompts 512 --random-input-len 1024 --random-output-len 512 --trust-remote-code 

Test Result

ALL2ALL Backend Naive, DBO, and DBO + small cudagraph size don't deadlock.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request re-introduces padding for dummy runs, which is a necessary bug fix. The changes correctly replace num_tokens with num_tokens_after_padding in various places. However, I've identified a potential critical issue where num_tokens_after_padding can exceed the allocated buffer size, leading to an out-of-bounds memory access.

Comment on lines 3204 to 3227

  if (self.supports_mm_inputs
          and not self.model_config.is_encoder_decoder):
      input_ids = None
-     inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
+     inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
      model_kwargs = {
          **model_kwargs,
          **self._dummy_mm_kwargs(num_reqs),
      }
  elif self.enable_prompt_embeds:
      input_ids = None
-     inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
-     model_kwargs = self._init_model_kwargs(num_tokens)
+     inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
+     model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
  else:
-     input_ids = self.input_ids.gpu[:num_tokens]
+     input_ids = self.input_ids.gpu[:num_tokens_after_padding]
      inputs_embeds = None

  if self.uses_mrope:
-     positions = self.mrope_positions.gpu[:, :num_tokens]
+     positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
  else:
-     positions = self.positions.gpu[:num_tokens]
+     positions = self.positions.gpu[:num_tokens_after_padding]

critical

There is a potential out-of-bounds memory access issue here and in the following lines that use num_tokens_after_padding for slicing. The value of num_tokens_after_padding can exceed self.max_num_tokens, which is the size of buffers like self.input_ids, self.positions, and self.inputs_embeds.

Here's how it can happen in _dummy_run when DBO is enabled:

  1. _dummy_run is called with num_tokens equal to self.max_num_tokens.
  2. ubatch_split is called, which in turn calls get_dp_padding_ubatch.
  3. get_dp_padding_ubatch calculates num_tokens_padded = round_up(num_tokens, 2). If self.max_num_tokens is odd, this results in self.max_num_tokens + 1.
  4. This padded value is used to calculate num_tokens_per_ubatch, which is then communicated across DP ranks. The maximum is taken.
  5. Back in _dummy_run, num_tokens_after_padding is calculated based on the result from ubatch_split, and can become self.max_num_tokens + 1.

Slicing tensors like self.input_ids.gpu[:num_tokens_after_padding] will then result in an out-of-bounds access, which can lead to memory corruption or a crash. This is a critical issue that needs to be addressed.
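
For illustration, here is a standalone sketch of the arithmetic in steps 3-5, assuming a simplified round_up and a hypothetical odd max_num_tokens; the names mirror the ones above, but this is not vLLM's actual code:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

max_num_tokens = 8191          # hypothetical odd max_num_batched_tokens
num_tokens = max_num_tokens    # step 1: _dummy_run called at the max

# Steps 3-4: pad to an even count so the tokens split across two ubatches.
num_tokens_padded = round_up(num_tokens, 2)            # 8192
num_tokens_per_ubatch = num_tokens_padded // 2         # 4096

# Step 5: the recombined padded count now exceeds the allocated buffers.
num_tokens_after_padding = 2 * num_tokens_per_ubatch   # 8192
assert num_tokens_after_padding > max_num_tokens       # slice would overrun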

@varun-sundar-rabindranath (Contributor Author)

REVIEWERS, PTAL

I have added an assert num_tokens_after_padding <= self.max_num_tokens to address this.
A better fix is to round up self.max_num_tokens here:

self.max_num_tokens = scheduler_config.max_num_batched_tokens

But we check max_num_tokens against scheduler_config.max_num_batched_tokens elsewhere in the code, and that is a reasonable check, so this probably needs to be handled more carefully.

However, we generally don't expect max_num_tokens to be odd, so this may never be an issue in practice. If it does happen, the assert should catch it.
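
For concreteness, a minimal sketch of the two options above (illustrative only, not the PR's actual diff; the numbers are made up):

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

max_num_tokens = 8192  # from scheduler_config.max_num_batched_tokens

# Option 1 (this PR): fail loudly instead of slicing past the buffers.
# The bound is <=, since the padded count may legitimately equal the
# buffer size.
num_tokens_after_padding = 8192
assert num_tokens_after_padding <= max_num_tokens, (
    "padded token count exceeds allocated buffers")

# Option 2 (deferred): size the buffers for the worst-case even padding,
# so an odd max_num_tokens can never overflow after round_up.
max_num_tokens_padded = round_up(max_num_tokens, 2)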

@varun-sundar-rabindranath
Contributor Author

Marking this as draft, as this might not be the full fix and could introduce bugs in the DBO case.

cc @LucasWilkinson @SageMoore @ProExpertProg @ilmarkov @tlrmchlsmth

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review October 4, 2025 16:11

@SageMoore (Contributor) left a comment

The biggest potential gotcha here is that num_tokens_after_padding will be divided in half for the DBO case. I suspect you will want to pad up the inputs to the "full" padded amount. The UbatchWrapper will take care of slicing them down to the ubatch sizes.

This is definitely annoying but once #25768 merges we won't have this difference.
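
Roughly, the suggested flow looks like this (a sketch under assumed shapes; slice_ubatches is a hypothetical stand-in for what UbatchWrapper does, not vLLM's actual API):

import torch

def slice_ubatches(buf: torch.Tensor, num_tokens_after_padding: int):
    # Hypothetical stand-in: split the fully padded input into two ubatches.
    half = num_tokens_after_padding // 2
    return buf[:half], buf[half:num_tokens_after_padding]

max_num_tokens = 8192
input_ids = torch.zeros(max_num_tokens, dtype=torch.long)

# Pad the inputs up to the "full" padded amount first...
num_tokens_after_padding = 4096
full_input = input_ids[:num_tokens_after_padding]

# ...and let the ubatch machinery slice them down per micro-batch.
ub0, ub1 = slice_ubatches(full_input, num_tokens_after_padding)
assert ub0.numel() == ub1.numel() == 2048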

@varun-sundar-rabindranath
Contributor Author

varun-sundar-rabindranath commented Oct 5, 2025

Marking this PR as draft (so we don't land it by mistake), as it deadlocks during DBO sanity checks (benchmarking).

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as draft October 5, 2025 01:06
@varun-sundar-rabindranath
Contributor Author

The biggest potential gotcha here is that num_tokens_after_padding will be divided in half for the DBO case. I suspect you will want to pad up the inputs to the "full" padded amount. The UbatchWrapper will take care of slicing them down to the ubatch sizes.

Thanks @SageMoore. I noticed the update to num_tokens_after_padding just before the model forward pass call. Should we do another slice of the buffers after that update? If we do that, I think it'll be correct for the DBO case and a no-op for the non-DBO case.
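
As a standalone sketch of that idea (not the actual model-runner code): slicing the buffers a second time after the update is a no-op whenever the value did not change:

import torch

max_num_tokens = 8192
positions = torch.zeros(max_num_tokens, dtype=torch.long)

num_tokens_after_padding = 4096
x = positions[:num_tokens_after_padding]   # initial slice

# ... ubatch logic may update num_tokens_after_padding here (DBO case) ...

x = positions[:num_tokens_after_padding]   # re-slice: same view if unchanged
assert x.numel() == num_tokens_after_padding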

@varun-sundar-rabindranath
Contributor Author

Resolved IRL

@mergify

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
Varun Sundar Rabindranath added 3 commits October 6, 2025 16:13
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
@mergify mergify bot removed the needs-rebase label Oct 6, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review October 6, 2025 16:27

@SageMoore (Contributor) left a comment

This looks reasonable, @varun-sundar-rabindranath. Thanks for the fix!

@mgoin (Member) left a comment

Seems reasonable to me, enabling CI. It would be nice to have a unit test to catalog this behavior as we refactor the model runner.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 6, 2025
@LucasWilkinson LucasWilkinson merged commit f23b4c0 into vllm-project:main Oct 6, 2025
49 checks passed