Conversation

@alexm-redhat (Collaborator) commented Oct 8, 2025

This PR executes the shared_experts part of the FusedMoE on a separate GPU stream, so that its execution is parallelized with the "selected_experts" part. This is possible because the outputs of the two are independent and are combined afterwards. Thanks @wenscarl for pointing this out.

For DeepSeek-R1 FP8 with FlashInfer latency kernels (trtllm-gen) on 8xB200 at batch size 32, TPOT improves from 23.35 ms to 22.09 ms (with the latest FlashInfer codebase), roughly a 5.7% end-to-end improvement.
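For readers less familiar with the pattern, the sketch below illustrates the overlap described above; it is a minimal illustration only (function names, tensor shapes, and the combine step are assumptions, not the PR's actual code): the shared experts are launched on a side stream, the gate and routed experts run on the default stream in the meantime, and the side stream is joined before the two outputs are combined.

```python
# Minimal sketch of the dual-stream shared-experts overlap (illustrative only).
import torch


def moe_forward(hidden_states, shared_experts, gate, routed_experts,
                shared_stream: torch.cuda.Stream):
    main_stream = torch.cuda.current_stream()

    # The side stream must wait until hidden_states is ready on the main stream.
    shared_stream.wait_stream(main_stream)
    with torch.cuda.stream(shared_stream):
        shared_out = shared_experts(hidden_states)

    # Meanwhile the main stream computes routing and the selected experts.
    router_logits = gate(hidden_states)
    routed_out = routed_experts(hidden_states, router_logits)

    # Join: the main stream waits for the shared-experts work before combining.
    main_stream.wait_stream(shared_stream)
    # Inform the caching allocator that shared_out is consumed on main_stream.
    shared_out.record_stream(main_stream)
    return shared_out + routed_out
```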

mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--26440.org.readthedocs.build/en/26440/

@mergify mergify bot added documentation Improvements or additions to documentation deepseek Related to DeepSeek models labels Oct 8, 2025
@alexm-redhat alexm-redhat marked this pull request as draft October 8, 2025 19:48
@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 2226 to 2235
use_explicit_se = (
    not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
    and self.shared_experts is not None
)
if use_explicit_se:
    current_stream = torch.cuda.current_stream()
    self.shared_experts_stream.wait_stream(current_stream)

router_logits, _ = self.gate(hidden_states)


P0: Avoid recomputing router logits via nonexistent gate

The updated forward_impl now calls self.gate(hidden_states) and discards the router_logits argument. FusedMoE does not define a gate module by default (the base property returns None), and most existing callers supply precomputed logits and have no gate attribute wired into the layer. This will raise a TypeError (NoneType is not callable) the first time a standard MoE block without shared experts executes, effectively breaking every model that previously passed logits into FusedMoE.


Comment on lines 19 to 41
def __init__(
    self,
    shared_experts: torch.nn.Module,
    gate: torch.nn.Module,
    use_overlapped: bool = True,
    **kwargs,
):
    super().__init__(**kwargs)
    self._shared_experts = shared_experts
    self._gate = gate


P0: New SharedFusedMoE constructor breaks existing call sites

The constructor now requires a gate module but none of the existing usages (e.g. in glm4_moe.py and llama4.py) pass this parameter. Instantiating those models will now raise a TypeError at import time because the additional positional argument has no default. Unless every caller is updated simultaneously, this change makes all current shared-fused MoE models unusable.
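One backward-compatible option, sketched here only as a general idea (not necessarily the fix adopted in this PR), is to give the new parameter a default so existing call sites that do not pass a gate keep working:

```python
# Sketch: defaulting the new gate argument so older SharedFusedMoE call sites
# keep working. Illustrative only; FusedMoE refers to the base class quoted in
# the surrounding diff, and storing use_overlapped here is an assumption.
import torch


class SharedFusedMoE(FusedMoE):
    def __init__(
        self,
        shared_experts: torch.nn.Module,
        gate: torch.nn.Module | None = None,
        use_overlapped: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self._gate = gate
        self.use_overlapped = use_overlapped
```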


@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request introduces a performance optimization by executing the shared_experts computation on a separate CUDA stream, allowing it to run in parallel with the selected_experts part. This is a great idea for improving throughput.

I've found a critical issue in the implementation that could cause crashes in certain MoE configurations. Please see my comment below for details and a suggested fix.

Comment on lines 2230 to 2234
if use_explicit_se:
    current_stream = torch.cuda.current_stream()
    self.shared_experts_stream.wait_stream(current_stream)

router_logits, _ = self.gate(hidden_states)
Contributor

critical

The call to self.gate(hidden_states) is unconditional, but self.gate can be None for MoE layers that do not use shared experts (e.g., the base FusedMoE class). This will cause a TypeError when forward_impl is called for such layers.

For instance, when DeepseekV2MoE is configured without shared experts, it creates a plain FusedMoE instance. Its forward method computes router_logits and passes them to self.experts.forward(). However, the modified forward_impl ignores these logits and attempts to call self.gate(), which is None for FusedMoE, leading to a crash.

To fix this, the gate computation should only happen when use_explicit_se is true, which is the case for shared experts where the gate is guaranteed to exist.

Suggested change

-if use_explicit_se:
-    current_stream = torch.cuda.current_stream()
-    self.shared_experts_stream.wait_stream(current_stream)
-router_logits, _ = self.gate(hidden_states)
+if use_explicit_se:
+    current_stream = torch.cuda.current_stream()
+    self.shared_experts_stream.wait_stream(current_stream)
+    router_logits, _ = self.gate(hidden_states)

mergify bot commented Oct 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 11, 2025
@alexm-redhat alexm-redhat marked this pull request as ready for review October 13, 2025 20:26
@chatgpt-codex-connector (bot)

💡 Codex Review

if isinstance(self.experts, SharedFusedMoE):
    fused_moe_out = self.experts(
        hidden_states=hidden_states, router_logits=hidden_states
    )
else:
    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    fused_moe_out = self.experts(

P0: Router logits skipped when SharedFusedMoE runs on GPU

In the SharedFusedMoE path the gate is no longer called before invoking the fused MoE kernel. DeepseekV2MoE.forward now calls self.experts(hidden_states, router_logits=hidden_states) whenever self.experts is a SharedFusedMoE, but FusedMoE.forward_native (the GPU path used by super().forward) still expects router_logits to already contain gate outputs and never invokes self.gate. As a result, the custom CUDA op receives hidden_states instead of logits, which both violates the expected (num_tokens, n_experts) shape and skips routing entirely, leading to runtime errors or incorrect routing for every GPU execution of SharedFusedMoE. The gating step needs to be reinstated or moved inside the GPU path before calling the custom op.


@alexm-redhat alexm-redhat force-pushed the moe_dual_stream branch 3 times, most recently from f7a3214 to 68d5997 Compare October 14, 2025 14:19
@mergify mergify bot removed the needs-rebase label Oct 14, 2025
@alexm-redhat alexm-redhat force-pushed the moe_dual_stream branch 6 times, most recently from 05900a9 to 6f30ab9 Compare October 16, 2025 16:11
@LucasWilkinson (Collaborator)

Do you mind quickly checking that this doesn't break DBO (i.e. --enable-dbo; you can test it with DeepSeek-V2-Lite with dp > 1 and DeepEP LL)?

@alexm-redhat (Collaborator, Author)

@LucasWilkinson will check now

@LucasWilkinson (Collaborator) left a comment

Nice optimization! Overall looks pretty good to me assuming DBO works; left a few comments

# TODO: Allow disabling of the separate shared experts stream for
# debug purposes. Remove this after more extensive testings with
# TP/DP and other execution modes
disable_shared_experts_stream = os.environ.get(
Collaborator

should we move this to envs.py?

Member

ditto, I think VLLM_DISABLE_SHARED_EXPERTS_STREAM is fine

Collaborator Author

Good idea, moved to envs.py
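For context, such a debug flag is typically read once with a False default; a minimal sketch under the VLLM_DISABLE_SHARED_EXPERTS_STREAM name suggested above (the parsing details and the exact wiring into vllm/envs.py are assumptions):

```python
# Sketch: reading the proposed flag as a plain bool that defaults to False.
# The variable name comes from the review suggestion; parsing is an assumption.
import os

VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = os.environ.get(
    "VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"
).lower() in ("1", "true")
```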

# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
    router_logits, _ = self.gate(hidden_states)
Collaborator

do we know if moving this out of the torch.compile region affects perf if we are not using multi-stream?

Collaborator Author

This won't run when multi-stream is disabled. As I understand it, the gate is always inside a torch.compile region, no?

Collaborator

Ah ok; I think it's a bit confusing that we always pass gate into SharedFusedMoE, and it's hard to tell the control flow in the modeling code. Maybe instead of:

if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
    fused_moe_out = self.experts(
        hidden_states=hidden_states, router_logits=hidden_states
    )
else:
    # router_logits: (num_tokens, n_experts)
    router_logits, _ = self.gate(hidden_states)
    fused_moe_out = self.experts(
        hidden_states=hidden_states, router_logits=router_logits
    )

we can do

class SharedFusedMoE(FusedMoE):
    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not self.use_overlapped:
            ...
            router_logits, _ = self.gate(hidden_states)
            fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
        else:
            shared_out, fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=hidden_states,
            )
        return shared_out, fused_out

This way, in the modeling code we can assume that if we are using SharedFusedMoE it will always handle the gate?

Collaborator Author

That's a good idea, let me try it. I may have some issues with the interface when removing the router_logits input, but let's see how I can remove it.

Collaborator Author

Actually there is a problem when the FusedMoE class is not SharedFusedMoE, since then the gate() needs to be outside anyway. I.e., the if/else cannot be removed; however, I can remove the non-trivial "overlap" check by simply providing a function like is_router_internal() on the FusedMoE base class. Will try to do it.

Collaborator Author

Added an is_internal_router property, so it is cleaner now.
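Roughly, the property lets the modeling code ask the layer whether it computes its own routing instead of checking the concrete class and use_overlapped; a sketch of that shape (illustrative only, not vLLM's exact code):

```python
# Sketch of the is_internal_router idea (illustrative, not the actual classes).
import torch


class FusedMoE(torch.nn.Module):
    @property
    def is_internal_router(self) -> bool:
        # The base layer expects the caller to compute router_logits.
        return False


class SharedFusedMoE(FusedMoE):
    def __init__(self, use_overlapped: bool = True):
        super().__init__()
        self.use_overlapped = use_overlapped

    @property
    def is_internal_router(self) -> bool:
        # When overlapping, the layer runs its own gate so the shared experts
        # can execute concurrently on the side stream.
        return self.use_overlapped
```

Model code can then branch on experts.is_internal_router rather than on isinstance checks.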

):
# Start the separate shared experts stream here since we want
# to run in parallel with the router/gate (next op below)
current_stream = torch.cuda.current_stream()
Collaborator

nit: use `from vllm.utils.torch_utils import current_stream`:

    self.shared_experts_stream.wait_stream(current_stream())

Collaborator Author

Nice, didn't know it was possible to simply import it. I suspect it is a constant handle.

# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
current_stream = torch.cuda.current_stream()
self.shared_experts_stream.wait_stream(current_stream)
Collaborator

nit: use `from vllm.utils.torch_utils import current_stream`, whose docstring explains why:

    """
    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
    it turns out that `torch.cuda.current_stream()` is quite expensive,
    as it will construct a new stream object at each call.
    here we patch `torch.cuda.set_stream` to keep track of the current stream
    directly, so that we can avoid calling `torch.cuda.current_stream()`.
    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
    from C/C++ code.
    """

self.shared_experts_stream.wait_stream(current_stream())

Collaborator Author

Changed
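The docstring quoted above describes the caching trick; roughly, it amounts to something like the following sketch (a simplified illustration of the idea, not vLLM's actual implementation):

```python
# Sketch: cache the current CUDA stream by patching torch.cuda.set_stream,
# so repeated lookups avoid constructing a new Stream object on every call.
import torch

_current_stream: torch.cuda.Stream | None = None
_original_set_stream = torch.cuda.set_stream


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    global _current_stream
    _current_stream = stream
    _original_set_stream(stream)


torch.cuda.set_stream = _patched_set_stream


def current_stream() -> torch.cuda.Stream:
    global _current_stream
    if _current_stream is None:
        # Fall back to the (expensive) query only on first use.
        _current_stream = torch.cuda.current_stream()
    return _current_stream
```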

# TODO: Allow disabling of the separate shared experts stream for
# debug purposes. Remove this after more extensive testings with
# TP/DP and other execution modes
disable_shared_experts_stream = os.environ.get(
    "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
)

if disable_shared_experts_stream is not None:
Member

Change the var from None by default to False and just do a regular bool check

Collaborator Author

Fixed to bool / False

fused_moe_out = self.experts(
    hidden_states=hidden_states, router_logits=router_logits
)
if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
Contributor

Does it mean that any model that wants to utilize the multi-stream feature must update its own model definition code? For example, will Qwen-Next also benefit from this change?

Collaborator Author

It is a bit complicated, since there are two improvements: (1) the use of the CUDA stream for shared_experts, and (2) moving the gate/router op to run after the shared_experts launch (so it is parallelized as well). For all models, (1) comes from the fact that they use SharedFusedMoE; however, for (2) you need to change the model code to move the gate inside (as is done here for DeepSeekV2).

In terms of perf, around 70% comes from (1) and 30% from (2).

Contributor
I see...

…el with the FusedMoE)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
@alexm-redhat alexm-redhat enabled auto-merge (squash) October 21, 2025 19:41
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 21, 2025
@alexm-redhat alexm-redhat merged commit 344a001 into main Oct 21, 2025
61 checks passed
@alexm-redhat alexm-redhat deleted the moe_dual_stream branch October 21, 2025 21:38
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Oct 21, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
sstamenk pushed a commit to sstamenk/vllm that referenced this pull request Oct 23, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: sstamenk <strahinja.stamenkovic@amd.com>
usberkeley pushed a commit to usberkeley/vllm that referenced this pull request Oct 23, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Oct 25, 2025
### What this PR does / why we need it?
Upgrade to new vllm commit:
vllm-project/vllm@c9461e0

- Fix many imports, caused by
vllm-project/vllm#26908
- Fix import ```sha256```, caused by
vllm-project/vllm#27169
- Remove ```SchedulerConfig.send_delta_data```, caused by
vllm-project/vllm#27142
- Fix ```FusedMoE``` because of dual stream execution, caused by
vllm-project/vllm#26440

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: v0.11.0rc3
- vLLM main:
vllm-project/vllm@17c540a

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
liziyu179 pushed a commit to nwpu-zxr/vllm-ascend that referenced this pull request Oct 25, 2025
### What this PR does / why we need it?
Upgrade to new vllm commit:
vllm-project/vllm@c9461e0

- Fix many imports, caused by
vllm-project/vllm#26908
- Fix import ```sha256```, caused by
vllm-project/vllm#27169
- Remove ```SchedulerConfig.send_delta_data```, caused by
vllm-project/vllm#27142
- Fix ```FusedMoE``` because of dual stream execution, caused by
vllm-project/vllm#26440

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.11.0rc3
- vLLM main:
vllm-project/vllm@17c540a

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…_experts" inside FusedMoE (vllm-project#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
@ZJY0516 (Contributor) commented Oct 27, 2025

I was wondering whether this could be done through a torch.compile custom pass?

    logger.info_once("Disabling MoE shared_experts cuda stream")
    self.shared_experts_stream = None
else:
    self.shared_experts_stream = torch.cuda.Stream()
Collaborator

@alexm-redhat We may need to have two global streams rather than two streams per FusedMoE layer. With this feature we see an explosion of streams, which may not be ideal.
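One way to address this, sketched below under assumed names, is a lazily created process-wide shared-experts stream that all layers reuse (so the process keeps just the default stream plus one auxiliary stream) instead of allocating a torch.cuda.Stream per FusedMoE instance:

```python
# Sketch: a single lazily created shared-experts stream reused by all layers.
# Illustrative only; the function and variable names are assumptions.
import torch

_SHARED_EXPERTS_STREAM: torch.cuda.Stream | None = None


def get_shared_experts_stream() -> torch.cuda.Stream:
    global _SHARED_EXPERTS_STREAM
    if _SHARED_EXPERTS_STREAM is None:
        _SHARED_EXPERTS_STREAM = torch.cuda.Stream()
    return _SHARED_EXPERTS_STREAM
```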
