@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +34,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -104,15 +105,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices
 
 
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -152,6 +155,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
 
     group_list = expert_token_nums.to(torch.int64)
@@ -208,7 +218,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states
 
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -873,6 +889,7 @@ def apply(
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -922,7 +939,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
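With shared_experts passed in, fused_experts_with_mc2 now returns a (routed_hidden_states, shared_hidden_states) pair instead of a single tensor, so the caller is expected to combine the two contributions itself. A minimal call-site sketch, assuming a hypothetical shared_experts module exposing the gate_up_proj / act_fn / down_proj interface used above; this is illustrative only, not the actual vllm-ascend call site:

    out = fused_experts_with_mc2(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        top_k=top_k,
        expert_map=expert_map,
        moe_all_to_all_group_name=group_name,  # HCCL group name used for the MC2 all-to-all
        shared_experts=shared_experts,         # None keeps the old single-tensor return
    )

    if isinstance(out, tuple):
        # Routed experts and the shared-expert MLP ran on separate streams;
        # here their outputs are simply summed (illustrative combination).
        routed_out, shared_out = out
        hidden_states = routed_out + shared_out
    else:
        hidden_states = out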
 | 
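The overlap itself rests on the two helpers imported from vllm_ascend.utils: npu_stream_switch(tag, priority) opens a scope whose operations are issued on a named secondary NPU stream, and npu_wait_tensor(a, b) orders work on tensor a after tensor b has been produced. That reading is inferred from their use in this diff rather than from a documented contract. A stripped-down sketch of the pattern, with placeholder names (overlapped_shared_mlp, mlp, routed_input):

    import torch

    from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


    def overlapped_shared_mlp(hidden_states: torch.Tensor, mlp,
                              routed_input: torch.Tensor) -> torch.Tensor:
        # The main stream keeps executing the routed-expert path (dispatch,
        # grouped matmul, combine); the shared-expert MLP is queued on a
        # secondary stream so the two proceed concurrently.
        with npu_stream_switch("moe_secondary", 0):
            # Do not start before the input consumed by the routed path is ready.
            npu_wait_tensor(hidden_states, routed_input)
            gate_up, _ = mlp.gate_up_proj(hidden_states)
            out, _ = mlp.down_proj(mlp.act_fn(gate_up))
        return out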