@@ -42,28 +42,40 @@ def _pre_process(
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         """Pre-process before MLP.
+
         Args:
-            hidden_states: Tensor of shape (num_tokens, hidden_size)
-            topk_ids: Tensor of shape (num_tokens, top_k_num)
-            topk_weights: Tensor of shape (num_tokens, top_k_num)
-            expert_map: Tensor mapping global expert IDs to local IDs
-            num_experts: Number of local experts
+            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
+            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
+            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
+            expert_map (torch.Tensor): Tensor of shape (global_num_experts,)
+                mapping global expert IDs to local expert IDs.
+            num_experts (int): Number of local experts (experts on this device).
+
         Returns:
-            permuted_hidden_states: Tensor of shape (num_tokens * top_k_num, hidden_size)
-            expert_tokens: Tensor of shape (num_experts,)
-            group_list_type: Argument for grouped matmul
+            tuple[torch.Tensor, torch.Tensor, int]: A tuple containing:
+                - permuted_hidden_states (torch.Tensor): Tensor of shape
+                  (num_tokens * top_k_num, hidden_size), obtained by permuting
+                  hidden_states based on topk_ids.
+                - expert_tokens (torch.Tensor): Tensor of shape (num_experts,)
+                  giving the number of tokens assigned to each expert.
+                - group_list_type (int): Type of group list, 0 for `cumsum`
+                  and 1 for `count`; it tells `npu_grouped_matmul` how to
+                  interpret expert_tokens.
+        Raises:
+            NotImplementedError: If the method is not implemented in the subclass.
         """
         pass
 
     @abstractmethod
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
         """Post-process after MLP.
+
         Args:
-            mlp_output: Tensor of shape (num_tokens * top_k_num, hidden_size)
-            hidden_states: Tensor of shape (num_tokens, hidden_size)
-        Returns:
-            None: This method mutates hidden_states in-place.
+            mlp_output (torch.Tensor): Tensor of shape
+                (num_tokens * top_k_num, hidden_size) produced by the MLP.
+            hidden_states (torch.Tensor): Tensor of shape
+                (num_tokens, hidden_size), updated in place with the final output.
         """
         pass
 
@@ -78,6 +90,7 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        """Dummy implementation; see moe_comm_pre_process_fake for details."""
         return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                          expert_map, num_experts)
 
@@ -166,11 +179,22 @@ def _post_process(self, mlp_output: torch.Tensor,
 
 
 class AllGatherCommImpl(MoECommMethod):
-    """This implementation is for the scenarios listed below:
-    1. `enable_expert_parallel=False`.
-    2. If `npu_moe_init_routing_v2` is available, we will support `enable_expert_parallel=True`,
-       and this implementation will become the default one, changing the name to `AllGather` at
-       the same time.
+    """This implementation is equivalent to NativeAllGatherCommImpl,
+    but uses NPU-specific ops for better performance.
+
+    It should be compatible with all scenarios, and is therefore the
+    default implementation for MoE communication methods. It uses
+    `torch_npu.npu_moe_init_routing_v2` for pre-processing and
+    `torch_npu.npu_moe_token_unpermute` for post-processing to handle
+    the token-to-expert mapping and communication efficiently.
+
+    NOTE(Yizhou): Ideally the ops would be used in matching pairs:
+    `torch_npu.npu_moe_init_routing_v2` with `torch_npu.npu_moe_finalize_routing`,
+    or `torch_npu.npu_moe_token_permute` with `torch_npu.npu_moe_token_unpermute`,
+    for pre- and post-processing respectively.
+    However, `npu_moe_finalize_routing` currently leads to accuracy issues,
+    so we use `torch_npu.npu_moe_token_unpermute` instead.
+    This is a workaround and should be removed once the issue is fixed.
     """
 
     def _pre_process(
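A tiny worked example of the token-to-expert routing this class handles (values chosen by hand, purely illustrative):

import torch

# 3 tokens, top_k = 2; expert ids picked by hand for illustration.
topk_ids = torch.tensor([[0, 2], [1, 0], [2, 2]])
# Grouping the six (token, slot) pairs by expert id gives
#   expert_tokens = [2, 1, 3]  with group_list_type = 1 ("count"), or
#   expert_tokens = [2, 3, 6]  with group_list_type = 0 ("cumsum"),
# and the permuted hidden states hold the rows of tokens [0, 1, 1, 0, 2, 2]
# in expert-major order. The post-processing unpermute step reverses this and
# applies topk_weights when combining the per-expert MLP outputs.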
@@ -392,6 +416,10 @@ def moe_comm_pre_process(
     expert_map: torch.Tensor,
     num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """Wrapper around the _pre_process method of the MoECommMethod instance
+    stored in the ForwardContext, so that it can be used as a custom op in
+    the vLLM framework.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.moe_comm_method
     return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
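The docstring above alludes to the standard PyTorch pattern of pairing an eager custom op with a "fake" (meta) implementation that torch.compile can trace. The sketch below illustrates that general pattern with `torch.library` and a made-up op name; it is not the registration code this module uses, which goes through vLLM's own custom-op registration helper.

import torch
from torch.library import custom_op, register_fake

# "demo::moe_pre_process" is a hypothetical op name for illustration only.
@custom_op("demo::moe_pre_process", mutates_args=())
def moe_pre_process(hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    # Eager implementation; here just a stand-in permutation.
    return hidden_states.repeat_interleave(topk_ids.shape[1], dim=0)

@register_fake("demo::moe_pre_process")
def _(hidden_states, topk_ids):
    # Shape/dtype-only version used by torch.compile when tracing the graph,
    # mirroring the role of moe_comm_pre_process_fake below.
    top_k_num = topk_ids.shape[1]
    return hidden_states.new_empty(
        (hidden_states.shape[0] * top_k_num, hidden_states.shape[1]))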
@@ -405,6 +433,9 @@ def moe_comm_pre_process_fake(
     expert_map: torch.Tensor,
     num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """Fake implementation of moe_comm_pre_process; torch.compile uses this
+    version when tracing the FX graph.
+    """
     top_k_num = topk_ids.shape[1]
     permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
     expert_tokens = torch.zeros((num_experts, ),
@@ -416,6 +447,10 @@ def moe_comm_pre_process_fake(
 
 def moe_comm_post_process(mlp_output: torch.Tensor,
                           hidden_states: torch.Tensor) -> None:
+    """Wrapper around the _post_process method of the MoECommMethod instance
+    stored in the ForwardContext, so that it can be used as a custom op in
+    the vLLM framework.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.moe_comm_method
     self._post_process(mlp_output, hidden_states)
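Put together, a caller is expected to invoke the two wrappers around the grouped expert MLP roughly as follows. `run_expert_mlps` is a placeholder name, and a ForwardContext with `moe_comm_method` set must already be active, as both wrappers require.

# Rough call order only; not code from this module.
permuted_hidden_states, expert_tokens, group_list_type = moe_comm_pre_process(
    hidden_states, topk_ids, topk_weights, expert_map, num_experts)

# Placeholder for the grouped expert MLP (e.g. driven by expert_tokens and
# group_list_type).
mlp_output = run_expert_mlps(permuted_hidden_states, expert_tokens, group_list_type)

# Writes the combined result back into hidden_states in place.
moe_comm_post_process(mlp_output, hidden_states)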