 
 import torch
 import torch_npu
-from vllm.distributed.parallel_state import get_tp_group
+from vllm.distributed.parallel_state import get_ep_group, get_tp_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 
@@ -34,13 +34,30 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        """Pre-process before MLP."""
+        """Pre-process before MLP.
+        Args:
+            hidden_states: Tensor of shape (num_tokens, hidden_size)
+            topk_ids: Tensor of shape (num_tokens, top_k_num)
+            topk_weights: Tensor of shape (num_tokens, top_k_num)
+            expert_map: Tensor mapping global expert IDs to local IDs
+            num_experts: Number of local experts
+        Returns:
+            permuted_hidden_states: Tensor of shape (num_tokens * top_k_num, hidden_size)
+            expert_tokens: Tensor of shape (num_experts,)
+            group_list_type: Argument for grouped matmul
+        """
         pass

     @abstractmethod
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        """Post-process after MLP."""
+        """Post-process after MLP.
+        Args:
+            mlp_output: Tensor of shape (num_tokens * top_k_num, hidden_size)
+            hidden_states: Tensor of shape (num_tokens, hidden_size)
+        Returns:
+            None: This method mutates hidden_states in-place.
+        """
         pass

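For readers unfamiliar with the permute/unpermute contract these new docstrings describe, here is a minimal, self-contained PyTorch sketch of the flow a `MoECommMethod` implements. It is illustrative only: the sizes and the identity stand-in for the grouped expert MLP are placeholder assumptions, not part of this PR.

```python
import torch

num_tokens, hidden_size, top_k_num, num_experts = 4, 8, 2, 3
hidden_states = torch.randn(num_tokens, hidden_size)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k_num))
topk_weights = torch.rand(num_tokens, top_k_num)

# _pre_process: replicate each token top_k_num times, then sort rows by expert id
# so every expert's tokens are contiguous for the grouped matmul.
token_indices = torch.arange(num_tokens).repeat_interleave(top_k_num)
flat_experts = topk_ids.reshape(-1)
sort_order = torch.argsort(flat_experts)
permuted_hidden_states = hidden_states[token_indices[sort_order]]
expert_tokens = torch.bincount(flat_experts, minlength=num_experts)  # "count" mode

# The grouped expert MLP would consume (permuted_hidden_states, expert_tokens,
# group_list_type); an identity MLP stands in for it here.
mlp_output = permuted_hidden_states

# _post_process: weight each row and scatter-add back to its source token,
# mutating hidden_states in place.
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(
    0, token_indices[sort_order],
    mlp_output * topk_weights.reshape(-1)[sort_order].unsqueeze(1))
hidden_states[:] = final_hidden_states
```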
@@ -63,9 +80,8 @@ def _post_process(self, mlp_output: torch.Tensor,
         pass


-class AllGatherCommImpl(MoECommMethod):
-    """This implementation is for the scenarios listed below:
-    1. `enable_expert_parallel=True`.
+class NativeAllGatherCommImpl(MoECommMethod):
+    """This implementation should be compatible with all scenarios.

     Note that this implementation purely consists of native PyTorch ops
     and does not use any NPU-specific ops. So the performance may not be optimal.
@@ -80,7 +96,6 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        print("Using AllGatherCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]

         # Generate token indices and flatten
@@ -98,6 +113,9 @@ def _pre_process(

         # Filter valid token-expert pairs
         mask = local_experts_flat != -1
+        # FIXME: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...],
+        # so we need to filter out invalid tokens by zeroing their weights.
+        # This is a workaround and should be removed once the issue is fixed.
         filtered_weights = torch.where(mask, weights_flat,
                                        torch.zeros_like(weights_flat)).to(
                                            self.dtype)
@@ -106,7 +124,6 @@ def _pre_process(
             local_experts_flat,
             torch.full_like(local_experts_flat, num_experts),
         ).to(topk_ids.dtype)
-        self.mask = mask

         # Sort by local expert IDs
         sort_indices = torch.argsort(filtered_experts.view(torch.float32))
@@ -121,39 +138,27 @@ def _pre_process(
                                    dtype=torch.int64)
         ones = torch.ones_like(filtered_experts, dtype=torch.int64)
         token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
-        token_counts = token_counts[:num_experts]
-        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
+        expert_tokens = token_counts[:num_experts]

         # Rearrange hidden_states
         permuted_hidden_states = hidden_states[self.sorted_token_indices]

-        group_list_type = 0
+        group_list_type = 1  # `count` mode

         return permuted_hidden_states, expert_tokens, group_list_type

     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        weighted_down_out = mlp_output * self.sorted_weights.unsqueeze(1)
+        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

         final_hidden_states = torch.zeros_like(hidden_states)
-
-        # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
-        # This created multiple NaN and index_add_ will mix them up which harms accuracy
-        # remove this mask and filter after it being fixed
-        num_valid_tokens = self.mask.sum()
-        valid_token_mask = (torch.arange(
-            0, self.sorted_token_indices.shape[0],
-            device=self.device).unsqueeze(1) < num_valid_tokens)
-        valid_output = torch.where(valid_token_mask, weighted_down_out,
-                                   torch.zeros_like(weighted_down_out)).to(
-                                       self.dtype)
         final_hidden_states.index_add_(0, self.sorted_token_indices,
-                                       valid_output)
+                                       mlp_output)

         hidden_states[:] = final_hidden_states


-class AllReduceCommImpl(MoECommMethod):
+class AllGatherCommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=False`.
     2. If `npu_moe_init_routing_v2` is available, we will support `enable_expert_parallel=True`,
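A quick illustration of the `group_list_type` change in the hunk above, which switches the grouped-matmul group list from prefix sums to raw per-expert counts (as the removed `cumsum` and the new `count` mode comment indicate). The concrete numbers are hypothetical:

```python
import torch

# Hypothetical distribution: 3 local experts receive 4, 0 and 2 tokens.
token_counts = torch.tensor([4, 0, 2], dtype=torch.int64)

# group_list_type = 0 ("cumsum" mode, the previous behaviour): prefix sums.
expert_tokens_cumsum = torch.cumsum(token_counts, dim=0)  # tensor([4, 4, 6])

# group_list_type = 1 ("count" mode, used now): pass the counts directly.
expert_tokens_count = token_counts  # tensor([4, 0, 2])
```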
@@ -169,70 +174,48 @@ def _pre_process(
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        print("Using AllReduceCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]

         self.topk_weights = topk_weights
         self.topk_ids = topk_ids

-        # 1. Prepare row indices for routing
-        row_idx_len = num_tokens * self.top_k_num
-        row_idx = torch.arange(row_idx_len,
-                               dtype=torch.int32,
-                               device=self.device)
-        row_idx = row_idx.view(self.top_k_num, -1).permute(1, 0).contiguous()
-
-        # 2. Initial routing to expand tokens and experts
-        permuted_hidden_states, expanded_row_idx, expanded_expert_idx = (
-            torch_npu.npu_moe_init_routing(
+        first_expert_idx = 0
+        if expert_map is not None:
+            # FIXME: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...],
+            # so we need to filter out invalid tokens by zeroing their weights.
+            # This is a workaround and should be removed once the issue is fixed.
+            mask = expert_map[topk_ids] != -1
+            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
+            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph.
+            self.topk_weights = torch.where(mask, topk_weights, 0.0)
+
+            first_expert_idx = get_ep_group().rank_in_group * num_experts
+        last_expert_idx = first_expert_idx + num_experts
+
+        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
+            torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
-                row_idx=row_idx,
-                expert_idx=topk_ids,
-                active_num=num_tokens,
+                topk_ids,
+                active_num=num_tokens * self.top_k_num,
+                expert_num=self.global_num_experts,
+                expert_tokens_num_type=1,  # Only support `count` mode now
+                expert_tokens_num_flag=True,  # Output `expert_tokens`
+                active_expert_range=[first_expert_idx, last_expert_idx],
+                quant_mode=-1,
             ))
-        # NOTE: Currently, V2 produces incorrect accuracy and weaker performance than V1
-        # first_expert_idx = 0
-        # if expert_map is not None:
-        #     first_expert_idx = torch.nonzero(expert_map != -1, as_tuple=False)[0].item()
-        #     last_expert_idx = first_expert_idx + num_experts
-        # permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
-        #     torch_npu.npu_moe_init_routing_v2(
-        #         hidden_states,
-        #         topk_ids,
-        #         active_num=num_tokens * self.top_k_num,
-        #         expert_num=self.global_num_experts,
-        #         expert_tokens_num_type=1,  # Only support `count` mode now
-        #         expert_tokens_num_flag=True,  # Output `expert_tokens`
-        #         active_expert_range=[first_expert_idx, last_expert_idx],
-        #         quant_mode=-1,
-        #     )
-        # )
         self.expanded_row_idx = expanded_row_idx
         permuted_hidden_states = permuted_hidden_states

-        # 3. Compute expert tokens
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            expanded_expert_idx, num_experts).to(torch.int64)
-        # NOTE: This is also for npu_moe_init_routing_v2
-        # expert_tokens = torch.cumsum(expert_tokens, 0)
-
-        group_list_type = 0
+        group_list_type = 1  # `count` mode

         return permuted_hidden_states, expert_tokens, group_list_type

     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        hidden_states[:] = torch_npu.npu_moe_finalize_routing(
-            mlp_output,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=self.topk_weights,
-            expanded_src_to_dst_row=self.expanded_row_idx,
-            export_for_source_row=self.topk_ids,
-            # NOTE: For npu_moe_init_routing_v2
-            # drop_pad_mode=2,
-        )
+        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
+            permuted_tokens=mlp_output,
+            sorted_indices=self.expanded_row_idx,
+            probs=self.topk_weights)


 class MC2CommImpl(MoECommMethod):
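For the `active_expert_range` passed to `npu_moe_init_routing_v2` in the hunk above, the range is simply the block of expert IDs owned by this expert-parallel rank. A hypothetical sizing (16 global experts, EP size 4) to make the arithmetic concrete:

```python
# Hypothetical: 16 global experts split across an EP group of size 4,
# so each rank owns num_experts = 4 contiguous local experts.
ep_rank, num_experts = 2, 4

first_expert_idx = ep_rank * num_experts           # 8
last_expert_idx = first_expert_idx + num_experts   # 12
# active_expert_range=[8, 12] covers this rank's local experts (IDs 8..11
# under this assumption); tokens routed elsewhere are filtered out.
```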
@@ -288,10 +271,27 @@ def _pre_process(
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Store tensors needed for post_process
-        self.topk_ids = topk_ids.clone()
-        self.topk_weights = topk_weights
+        self.topk_ids = topk_ids
+        self.topk_weights = topk_weights.to(torch.float32)
         self.mc2_mask = get_forward_context().mc2_mask

+        # tp_size = get_tensor_model_parallel_world_size()
+        # self.chunked_hidden_states = torch.tensor_split(hidden_states,
+        #                                                 tp_size,
+        #                                                 dim=0)
+        # chunked_topk_ids = torch.tensor_split(self.topk_ids,
+        #                                       tp_size,
+        #                                       dim=0)
+        # chunked_topk_weights = torch.tensor_split(self.topk_weights,
+        #                                           tp_size,
+        #                                           dim=0)
+        # chunked_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
+        # tp_rank = get_tensor_model_parallel_rank()
+        # hidden_states = self.chunked_hidden_states[tp_rank]
+        # self.topk_ids = chunked_topk_ids[tp_rank]
+        # self.topk_weights = chunked_topk_weights[tp_rank]
+        # self.mc2_mask = chunked_mc2_mask[tp_rank]
+
         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,
@@ -326,7 +326,7 @@ def _pre_process(
             expert_tokens,
             self.ep_recv_counts,
             self.tp_recv_counts,
-        ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs)[:6]
+        ) = dispatch(**dispatch_kwargs)[:6]

         group_list_type = 1

@@ -337,7 +337,7 @@ def _post_process(self, mlp_output: torch.Tensor,
         combine_kwargs = {
             "expand_x": mlp_output,
             "expert_ids": self.topk_ids,
-            "expert_scales": self.topk_weights.to(torch.float32),
+            "expert_scales": self.topk_weights,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": self.global_num_experts,
@@ -366,12 +366,17 @@ def _post_process(self, mlp_output: torch.Tensor,
                 "x_active_mask": self.mc2_mask,
             })

-        if self.enable_dispatch_v2:
-            hidden_states[:] = torch_npu.npu_moe_distribute_combine_v2(
-                **combine_kwargs)
-        else:
-            hidden_states[:] = torch_npu.npu_moe_distribute_combine(
-                **combine_kwargs)
+        combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
+
+        hidden_states[:] = combine(**combine_kwargs)
+
+        # final_hidden_states = combine(**combine_kwargs)
+
+        # dist.all_gather(list(self.chunked_hidden_states), final_hidden_states, get_tp_group().device_group)
+
+        # final_hidden_states = torch.cat(self.chunked_hidden_states, dim=0)
+
+        # hidden_states[:] = final_hidden_states


 def moe_comm_pre_process(