 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -47,6 +48,8 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
+    graph_mode: bool = False,
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
@@ -88,6 +91,10 @@ def fused_experts_with_mc2(
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -102,6 +109,11 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_gate_up, expand_x, enabled=graph_mode)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     # TODO: Remove this in the future.
     gate_up_out = torch.cat(gate_up_out_list, dim=0)
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
@@ -145,7 +157,15 @@ def fused_experts_with_mc2(
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
 
-    return hidden_states
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_act, down_out_list, enabled=graph_mode)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
+    if shared_experts is None:
+        return hidden_states
+    else:
+        return hidden_states, shared_hidden_states
 
 
 # currently expert parallelism implemented with all2all
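
Note: the additions above interleave the shared expert's MLP with the routed-expert MC2 path by pushing it onto a named secondary stream. The following is a minimal, illustrative sketch of that overlap pattern, not part of the PR. It assumes only what the diff itself shows: npu_stream_switch(name, priority, enabled=...) is a context manager that routes the enclosed ops onto a secondary stream (a no-op when disabled), and npu_wait_tensor(a, b, enabled=...) orders subsequent work on `a` after `b` is produced. `routed_fn` and the shared-expert attribute names are placeholders.

# Illustrative sketch of the dual-stream overlap used above (not part of the PR).
import torch
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def routed_with_shared_overlap(hidden_states: torch.Tensor,
                               routed_fn,
                               shared_experts: torch.nn.Module,
                               graph_mode: bool = False):
    # Launch the shared expert's first GEMM on the secondary stream; it only
    # needs hidden_states, which is already available on the default stream.
    with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)

    # The routed-expert computation keeps running on the default stream.
    router_hidden_states = routed_fn(hidden_states)

    with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
        # Sequence the rest of the shared MLP against a routed-path tensor,
        # mirroring the npu_wait_tensor calls in the diff above.
        npu_wait_tensor(shared_gate_up, router_hidden_states, enabled=graph_mode)
        shared_act = shared_experts.act_fn(shared_gate_up)
        shared_hidden_states, _ = shared_experts.down_proj(shared_act)

    return router_hidden_states, shared_hidden_states
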
@@ -587,6 +607,8 @@ def __init__(self, moe: MoEConfig = None):
         self.ep_size = ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
+        self.graph_mode = vllm_config.get("additional_config",
+                                          {}).get("enable_graph_mode", False)
 
         try:
             device_group = ep_group.device_group
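
For context, `self.graph_mode` is driven by the `enable_graph_mode` key of `additional_config`. A hedged sketch of how a deployment might set it, assuming vLLM's `additional_config` engine argument reaches this code as the dict-like `vllm_config` used above (the model name is a placeholder):

# Illustrative only: enabling enable_graph_mode so self.graph_mode above is True.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",           # placeholder model
    additional_config={"enable_graph_mode": True},  # read in __init__ above
)
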
@@ -624,6 +646,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +687,37 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+                graph_mode=self.graph_mode,
+            )
         elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -815,7 +847,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
         assert self.quant_method is not None
 
         if top_k:
@@ -842,7 +875,9 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )
 
         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
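
For reference, a hypothetical caller of the extended AscendFusedMoE.forward. Class and attribute names such as CustomDeepseekV2MoE, self.gate, self.experts, self.shared_experts and self.routed_scaling_factor are illustrative, not taken from this diff, and it is assumed that forward propagates the (routed, shared) tuple returned by apply when shared_experts is passed:

# Hypothetical usage sketch of the new shared_experts argument (not part of the PR).
import torch


class CustomDeepseekV2MoE(torch.nn.Module):  # illustrative wrapper module

    def forward(self, hidden_states: torch.Tensor,
                is_prefill: bool) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        routed_out, shared_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=self.top_k,
            shared_experts=self.shared_experts,  # new: overlapped with the routed path
        )
        return routed_out * self.routed_scaling_factor + shared_out
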