
Conversation

@tlrmchlsmth
Member

@tlrmchlsmth tlrmchlsmth commented Sep 3, 2025

Purpose

Currently, when attention runs with TP and --enable-expert-parallel is set, the MoE layers do duplicate work when using DeepEP: the output of attention is replicated across TP ranks, and each copy of each token is dispatched to the EP ranks it gets routed to, multiplying the amount of work by tp_size.

This PR avoids this duplicate work by ensuring the input to the MoE layer is sequence parallel instead of replicated.

Notes:

  • The performance bug applies to any MoE model, but this PR only fixes it for DeepSeekV3.
  • We can do better by replacing the all_reduce at the end of attention with a reduce_scatter. This reduces the amount of computation but is a little more invasive to the model definition, since we need to handle the sharding of the residuals. It has the extra effect of de-duplicating the work done during layer norms (a minor improvement).
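To make the fix concrete, here is a minimal single-process sketch of the chunk-then-gather pattern (illustrative only: chunk_for_rank and the loop over ranks are hypothetical stand-ins; the real code takes one chunk per TP rank and restores the replicated layout with an all_gather over the TP group after the MoE layer):

```python
# Hypothetical, simplified sketch of sequence-parallel chunking for the MoE input.
import torch

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad the token dimension so it divides evenly by tp_size, then take this
    # rank's contiguous slice. Each rank then dispatches only its own tokens
    # to the experts instead of a full replicated copy.
    seq_len = x.size(0)
    pad_len = (-seq_len) % tp_size
    if pad_len:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_len))
    chunk = x.size(0) // tp_size
    return x[tp_rank * chunk:(tp_rank + 1) * chunk]

tp_size, seq_len, hidden = 4, 7, 16
hidden_states = torch.randn(seq_len, hidden)   # replicated on every TP rank

# Each "rank" works on its own chunk (the MoE layer would run here).
per_rank_out = [chunk_for_rank(hidden_states, r, tp_size) for r in range(tp_size)]

# After the MoE layer, an all_gather (modeled here by torch.cat) restores the
# replicated layout; the padding added above is dropped again.
gathered = torch.cat(per_rank_out, dim=0)[:seq_len]
assert torch.equal(gathered, hidden_states)
```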

Test Plan

lm_eval --model local-completions --tasks gsm8k --model_args model=deepseek-ai/DeepSeek-R1-0528,base_url=http://infra-wide-ep-inference-gateway-istio.llm-d-wide-ep.svc.cluster.local/v1/completions,num_concurrent=1000,max_retries=3,tokenized_requests=False

Test Result

This PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9583|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9560|±  |0.0056|

5b31cb1 (last good commit before #24119 landed):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9591|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9568|±  |0.0056|

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@mergify mergify bot added the deepseek (Related to DeepSeek models) label Sep 3, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly introduces sequence parallelism to the MoE layer in the DeepseekV2 model to prevent redundant computations when using both Tensor Parallelism and Expert Parallelism. The approach of chunking the input before the MoE layer and gathering the output afterward is sound. I've found one critical issue that could lead to a runtime error, which I've detailed in a specific comment.

Comment on lines 155 to 157
# If using expert parallel, ensure the input to the experts is
# SP to avoid duplicate work.
# Not needed for pplx-kernels as it can handle duplicate input tokens.
Member Author


@abcdabcd987 and @nandor could you double-check me here: Can pplx handle replicated input tokens in the TP attn + EP MoE case?

# If using expert parallel, ensure the input to the experts is
# SP to avoid duplicate work.
# Not needed for pplx-kernels as it can handle duplicate input tokens.
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat Sep 3, 2025


I think we should call this use_sequence_parallel_mlp since we use seq parallelism for just the mlp layer here

Member Author


Not sure I like this because the MLP being sequence parallel is kind of a side effect. And we need to pass it into the fused_moe layer for the chunking.

I'm not a fan of the sequence_parallel name though so definitely open to suggestions

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@robertgshaw2-redhat
Collaborator

This looks good to me. Just left some comments on explaining the parallelism setup.

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@tlrmchlsmth tlrmchlsmth added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 3, 2025
@robertgshaw2-redhat
Collaborator

There are genuine failures in the CI related to DeepSeek MTP.

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@mergify

mergify bot commented Sep 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tlrmchlsmth.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 3, 2025
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@mergify mergify bot removed the needs-rebase label Sep 3, 2025
@tlrmchlsmth
Member Author

Seeing some issues with CUDA graphs on:

ms-wide-ep-llm-d-modelservice-decode-0-1 vllm-worker-decode (EngineCore_7 pid=290) (VllmWorker TP0 pid=314) ERROR 09-03 19:20:29 [multiproc_executor.py:611]   File "/workspace/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1663, in process_chunk
ms-wide-ep-llm-d-modelservice-decode-0-1 vllm-worker-decode (EngineCore_7 pid=290) (VllmWorker TP0 pid=314) ERROR 09-03 19:20:29 [multiproc_executor.py:611]     staged_hidden_states.copy_(hidden_states, non_blocking=True)
ms-wide-ep-llm-d-modelservice-decode-0-1 vllm-worker-decode (EngineCore_7 pid=290) (VllmWorker TP0 pid=314) ERROR 09-03 19:20:29 [multiproc_executor.py:611] RuntimeError: output with shape [1, 7168] doesn't match the broadcast shape [0, 7168]

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@tlrmchlsmth
Member Author

It was a torch.compile issue. I ended up having to work around it by wrapping the sequence-parallel chunking code in a custom op.

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@nvpohanh
Contributor

nvpohanh commented Sep 5, 2025

cc @weireweire

Comment on lines +125 to +161
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        x = nn.functional.pad(x, (0, 0, 0, pad_len))

    chunk = x.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(x, 0, start, chunk)


def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    seq_len = cdiv(x.size(0), tp_size)
    shape = list(x.shape)
    shape[0] = seq_len
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
    return out


direct_register_custom_op(
    op_name="sequence_parallel_chunk",
    op_func=sequence_parallel_chunk,
    mutates_args=[],
    fake_impl=sequence_parallel_chunk_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
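For context (a hedged note, not verbatim from the PR): after registration via direct_register_custom_op, call sites are expected to invoke the op through the torch.ops namespace (presumably torch.ops.vllm.sequence_parallel_chunk(hidden_states), assuming the default vllm op library) rather than calling the Python function directly, so torch.compile treats the chunking as an opaque call and cannot specialize the chunk length to 0. After the MoE layer, an all_gather over the TP group presumably restores the full replicated sequence, and any padding added for divisibility is sliced off.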
Member Author


cc @zou3519 @ProExpertProg in case you have better ideas than this wrap-it-in-a-custom-op hack

Collaborator


@tlrmchlsmth do you have the original error message and/or a stack trace?

Collaborator

@zou3519 zou3519 Sep 11, 2025


Also, this custom operator is technically incorrect: the output is a view of the input, which means bad things can happen in the presence of mutation. I don't know if vLLM specifically will hit any of those issues; it depends on how it's being used.
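For reference, one way to address the aliasing concern (a sketch under the assumption that an extra copy of the chunk is acceptable; this is not the code in the PR) is to have the op return a fresh tensor rather than a view, e.g.:

```python
# Hypothetical alias-free variant (illustration only, not the PR's code):
# cloning the narrowed slice gives the op a fresh output tensor that does not
# share storage with its input, so in-place mutation of either tensor can no
# longer silently affect the other.
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)


def sequence_parallel_chunk_no_alias(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # Pad the token dimension so it divides evenly across TP ranks.
    pad_len = (-x.size(0)) % tp_size
    if pad_len:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_len))

    chunk = x.size(0) // tp_size
    # .clone() materializes a copy instead of returning a view of x.
    return torch.narrow(x, 0, tp_rank * chunk, chunk).clone()
```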

tlrmchlsmth and others added 4 commits September 7, 2025 17:14
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@simon-mo simon-mo merged commit 955c624 into vllm-project:main Sep 9, 2025
40 of 42 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
…EP MoE (vllm-project#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…EP MoE (vllm-project#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…EP MoE (vllm-project#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…EP MoE (vllm-project#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…EP MoE (vllm-project#24134)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

Labels

deepseek (Related to DeepSeek models), ready (ONLY add when PR is ready to merge/full CI is needed), speculative-decoding

5 participants