Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion vllm_ascend/ops/common_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
Expand Down Expand Up @@ -373,6 +374,21 @@ def __init__(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]

def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
the outputs are already aggregated across tensor parallel ranks in the
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
outputs since each rank only has partial outputs.
"""
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
Expand Down Expand Up @@ -415,6 +431,38 @@ def forward_impl(self, hidden_states: torch.Tensor,
return final_hidden_states


class AscendSharedFusedMoE(AscendFusedMoE):

def __init__(
self,
shared_experts: torch.nn.Module,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped

def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = self._shared_experts(hidden_states)

# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
shared_out = tensor_model_parallel_all_reduce(shared_out)

fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading

Expand Down
1 change: 1 addition & 0 deletions vllm_ascend/patch/platform/patch_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
#

import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe # noqa
21 changes: 21 additions & 0 deletions vllm_ascend/patch/platform/patch_common/patch_shared_fused_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vllm.model_executor.models import deepseek_v2, llama4

from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE

deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE
llama4.SharedFusedMoE = AscendSharedFusedMoE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the latest code, a common class has been extracted, named SharedFusedMoE.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be patch in worker module, since it'll be used by worker process?

Copy link
Collaborator Author

@yiz-liu yiz-liu Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the latest code, a common class has been extracted, named SharedFusedMoE.

Yes, I'm well aware of that, except patching that class directly (I assume you are suggesting shared_fused_moe.SharedFusedMoE = AscendSharedFusedMoE) will have no effect at all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be patch in worker module, since it'll be used by worker process?

Tested in both directory, both are fine, so we can stick to this for now.

Loading