 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py

+import math
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               npu_stream_switch, npu_wait_tensor)
+from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
+                               get_ascend_soc_version, npu_stream_switch,
+                               npu_wait_tensor)

 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -117,9 +120,24 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
-    shared_experts: Optional[Any] = None
+    shared_experts: Optional[Any] = None,
+    is_torchair: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+    quant_mode = 0
+    ep_group = get_ep_group()
+    ep_rank_id = ep_group.rank_in_group
+    ep_world_size = ep_group.world_size
+    tp_world_size = get_tp_group().world_size
+
+    # NOTE: `global_bs` should equal `max_num_tokens_across_dp` * `ep_world_size`,
+    # where `max_num_tokens_across_dp` has already been split into
+    # `tp_world_size` parts, hence the ceil-division here.
+    global_bs = math.ceil(get_forward_context().max_tokens_across_dp /
+                          tp_world_size) * ep_world_size
+
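+    # Worked example (hypothetical sizes, for illustration only): with
+    # max_tokens_across_dp = 256, tp_world_size = 2 and ep_world_size = 16,
+    # global_bs = ceil(256 / 2) * 16 = 2048.
+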
+    # NOTE: Currently, on the A3 SoC or in torchair graph mode, some extra
+    # parameters must be passed into dispatch & combine (see `stage1_kwargs`
+    # and `stage3_kwargs` below).
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
+                       or is_torchair)
+
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -130,23 +148,20 @@ def fused_experts_with_mc2(
         "global_bs": global_bs,
     }

-    quant_mode = 0
-    ep_group = get_ep_group().device_group
-    assert torch.distributed.get_world_size() == ep_group.world_size
-    local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
-
     stage1_kwargs = {
         "scales": None,
         "quant_mode": quant_mode,
         "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": 1,
-        "tp_rank_id": 0,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage1_kwargs.update({
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
+
     kwargs_mc2.update(stage1_kwargs)

     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
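
Note (orientation, not part of this commit): the dispatch call above returns a
tuple; the elided lines between these hunks unpack it, and the EP/TP receive
counts it reports are what `stage3_kwargs` below feeds back into
`npu_moe_distribute_combine` as send counts. A minimal sketch of that
unpacking, assuming the tuple layout used by the surrounding file:

    # Sketch only: the exact tuple layout is defined by
    # torch_npu.npu_moe_distribute_dispatch, not by this commit.
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = \
        output[0:5]
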
@@ -205,14 +220,16 @@ def fused_experts_with_mc2(
     stage3_kwargs = {
         "ep_send_counts": ep_recv_counts,
         "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": all_to_all_group_size,
-        "ep_rank_id": local_rank,
-        "tp_send_counts": tp_recv_counts,
-        # "group_tp": self.moe_rs_group_name,
-        "group_tp": moe_all_to_all_group_name,
-        "tp_world_size": 1,
-        "tp_rank_id": 0,
+        "ep_world_size": ep_world_size,
+        "ep_rank_id": ep_rank_id,
     }
+    if need_extra_args:
+        stage3_kwargs.update({
+            "tp_send_counts": tp_recv_counts,
+            "group_tp": moe_all_to_all_group_name,
+            "tp_world_size": 1,
+            "tp_rank_id": 0,
+        })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
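
The stage1/stage3 gating is symmetric; a standalone, pure-Python sketch of the
pattern (`build_stage_kwargs` is a hypothetical wrapper with illustrative
sizes, no NPU required):

    def build_stage_kwargs(need_extra_args: bool, group_name: str) -> dict:
        # Mirrors the diff: TP-related fields only when on A3 or in torchair.
        kwargs = {"group_ep": group_name, "ep_world_size": 16, "ep_rank_id": 0}
        if need_extra_args:
            kwargs.update({
                "group_tp": group_name,  # EP comm group reused for TP
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        return kwargs

    print(build_stage_kwargs(False, "mc2"))  # EP-only arguments
    print(build_stage_kwargs(True, "mc2"))   # extra TP arguments included
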
@@ -842,17 +859,14 @@ def __init__(self, moe: MoEConfig = None):
         super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()

-        self.ep_group = get_ep_group()
-        self.ep_size = self.ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.local_batch_size = self.global_batch_size // self.ep_size
         self.max_model_len = vllm_config.model_config.max_model_len

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

         try:
-            device_group = self.ep_group.device_group
+            device_group = get_ep_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
@@ -939,7 +953,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1049,17 +1064,15 @@ def __init__(
             self.local_num_experts, self.expert_map = \
                 expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id,
-                    get_ep_group().rank_in_group)
+                    self.ep_rank)
             self.log2phy = expert_load_balancer.get_rank_log2phy_map(
-                self.moe_instance_id,
-                get_ep_group().rank_in_group)
+                self.moe_instance_id, self.ep_rank)
             self.global_redundant_expert_num = \
                 expert_load_balancer.get_global_redundant_expert_num()
         else:
             # Create a tensor of size num_experts filled with -1
             self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
+                self.ep_size, self.ep_rank, self.global_num_experts)

         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_moe = \
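
Per the "filled with -1" comment above, `determine_expert_map` yields a
global-to-local expert index map. A runnable sketch of that shape
(`sketch_expert_map` is illustrative only, assuming a contiguous even split
of experts across EP ranks):

    import torch

    def sketch_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
        local = global_num_experts // ep_size  # assumed even split
        expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
        start = ep_rank * local
        # experts owned by this rank get local indices 0..local-1; others stay -1
        expert_map[start:start + local] = torch.arange(local, dtype=torch.int32)
        return local, expert_map

    # e.g. 8 experts over 4 EP ranks: rank 1 owns global experts 2 and 3
    print(sketch_expert_map(4, 1, 8))
    # (2, tensor([-1, -1, 0, 1, -1, -1, -1, -1], dtype=torch.int32))
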
@@ -1102,7 +1115,6 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

-        self.ep_group = get_ep_group()
         # NOTE: self.tp_group is not expert_tp_group
         self.tp_group = get_tp_group().device_group
         self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -1148,7 +1160,7 @@ def forward(self,
         # NOTE: In torchair graph mode the input has already been padded in
         # model_runner_v1, so only pad here otherwise (or during prefill).
         if not self.torchair_graph_enabled or is_prefill:
             max_num_tokens_across_dp = get_forward_context(
-            ).dp_metadata.max_tokens_across_dp_cpu
+            ).max_tokens_across_dp
             if num_tokens < max_num_tokens_across_dp:
                 hidden_states = nn.functional.pad(
                     hidden_states,
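
A runnable, pure-Python illustration of this DP padding (hypothetical sizes;
the pad spec is an assumption, since the hunk is truncated after
`hidden_states,`):

    import torch
    import torch.nn.functional as F

    num_tokens, hidden_size = 3, 4
    max_num_tokens_across_dp = 5
    hidden_states = torch.randn(num_tokens, hidden_size)
    if num_tokens < max_num_tokens_across_dp:
        # append zero rows so every DP rank contributes the same token count
        hidden_states = F.pad(
            hidden_states, (0, 0, 0, max_num_tokens_across_dp - num_tokens))
    print(hidden_states.shape)  # torch.Size([5, 4])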