
Conversation

Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented May 26, 2025

Enable CUDA Graphs for DP + All2All kernels.

Fixes:

  1. The input buffers passed to the quant_method aren't captured properly when using CUDA Graphs + torch.compile. This PR introduces persistent staging buffers into which the hidden_states and router_logits are copied; it is these buffers that get passed to the quant_method (a minimal sketch of the idea follows this list).
  2. All DP ranks must invoke the same number of dispatch and combine kernels, because these kernels synchronize across DP ranks; when this requirement isn't respected, it manifests as a deadlock. To that end, this PR introduces a get_dp_padding method in gpu_model_runner.py.
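
A minimal sketch of the staging-buffer idea from fix 1. The buffer names mirror batched_hidden_states / batched_router_logits from the diff below; the sizes, the staging_copy helper, and the way the dtype is obtained are illustrative assumptions, not the actual vLLM code path.

import torch

# Staging buffers are allocated once, sized for the largest batch a CUDA
# graph will be captured for, so graph replay always sees the same memory
# addresses. Sizes and device handling here are illustrative.
MAX_TOKENS, HIDDEN_SIZE, NUM_EXPERTS = 1024, 4096, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
act_dtype = torch.bfloat16  # assumption: in vLLM this comes from the model config

batched_hidden_states = torch.zeros((MAX_TOKENS, HIDDEN_SIZE),
                                    dtype=act_dtype, device=device)
batched_router_logits = torch.zeros((MAX_TOKENS, NUM_EXPERTS),
                                    dtype=act_dtype, device=device)

def staging_copy(hidden_states: torch.Tensor, router_logits: torch.Tensor):
    # Copy this step's inputs into the persistent buffers and return views
    # sized to the current batch; these views are what would be handed to
    # the quant_method. `staging_copy` is a hypothetical helper name.
    n = hidden_states.shape[0]
    batched_hidden_states[:n].copy_(hidden_states)
    batched_router_logits[:n].copy_(router_logits)
    return batched_hidden_states[:n], batched_router_logits[:n]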

Tests:
Verified correctness using lm_eval locally on 4xH100.

Varun Sundar Rabindranath added 3 commits May 26, 2025 12:00
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 26, 2025
Varun Sundar Rabindranath added 2 commits May 26, 2025 13:40
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@varun-sundar-rabindranath
Contributor Author

@bnellnm @youkaichao @tlrmchlsmth PTAL! Thanks.

Varun Sundar Rabindranath added 2 commits May 26, 2025 13:55
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
    act_dtype = torch.get_default_dtype()
Contributor

Is this always correct? Can we make this a part of the moe config (even if we get it from the same place).

Contributor Author

I am not sure if it is always correct. @mgoin, is there a better way to get the activation dtype here? Any pointers? Thanks.

Contributor Author

Switched to using the dtype from model_config after speaking to Michael IRL.
@bnellnm Regarding moving act_dtype to the moe config: act_dtype is only used here, so do you see value in storing it in the moe config anyway?
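
For reference, a minimal sketch of what "dtype from model_config" could look like, assuming the resolved activation dtype is exposed as a dtype attribute on the model config (the exact attribute path in vLLM may differ):

import torch

def resolve_act_dtype(model_config) -> torch.dtype:
    # Prefer the dtype recorded on the model config over the process-wide
    # default, which is only correct when torch.set_default_dtype() happens
    # to match the model. `model_config.dtype` is an assumption about where
    # the resolved activation dtype lives.
    dtype = getattr(model_config, "dtype", None)
    return dtype if isinstance(dtype, torch.dtype) else torch.get_default_dtype()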

Comment on lines 1115 to 1125
# Gather num_tokens across dp rank
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                 device="cpu",
                                 dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
torch.distributed.all_reduce(num_tokens_tensor,
                             group=get_dp_group().cpu_group)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor).item()
return max_tokens_across_dp_cpu - num_tokens
Contributor

Can't you get this from the forward_context?

Contributor Author

This is called before the forward_context is set.

Member

Looks like each DP rank is padding out to the maximum number of tokens? That's definitely too much padding in the chunked prefill case. Might be OK for disagg P/D.

BTW, could you factor this bit of code out so we don't duplicate it both here and in the forward_context?

Contributor Author

Yeah, the padding is pretty aggressive. I can come back to this in a follow-up PR (I tried a simpler approach and ran into deadlock issues).

BTW, could you factor this bit of code out so we don't duplicate it both here and in the forward_context?
Refactored the code a bit 👍
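
As an illustration of the refactor being discussed, here is a sketch of a shared helper that both gpu_model_runner and the forward-context setup could call. The function names and signatures are hypothetical, not the actual vLLM code.

import torch
import torch.distributed as dist

def max_tokens_across_dp(num_tokens: int, dp_rank: int, dp_size: int,
                         cpu_group) -> int:
    # All-reduce each DP rank's token count over a CPU process group and
    # return the maximum, so every rank can pad to the same batch size and
    # therefore issue the same number of dispatch/combine kernels.
    counts = torch.zeros(dp_size, dtype=torch.int32)
    counts[dp_rank] = num_tokens
    dist.all_reduce(counts, group=cpu_group)
    return int(counts.max().item())

def get_dp_padding(num_tokens: int, dp_rank: int, dp_size: int,
                   cpu_group) -> int:
    # Number of padding tokens this rank must add to match the largest rank.
    return max_tokens_across_dp(num_tokens, dp_rank, dp_size, cpu_group) - num_tokens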


Varun Sundar Rabindranath added 7 commits May 27, 2025 12:04
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Member

nice cleanup

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label May 28, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) May 28, 2025 20:58
@tlrmchlsmth tlrmchlsmth merged commit 7951d78 into vllm-project:main May 28, 2025
74 checks passed
amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: amit <amit.man@gmail.com>