
Commit 9abbf4b

docs: Improve documentation for MoE communication methods
Enhances and adds docstrings across the MoE communication methods to improve clarity and provide more detailed explanations. The docstring for `AllGatherCommImpl` is updated to reflect that it is now the default implementation and to explain a workaround for an accuracy issue.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 64d5ceb commit 9abbf4b

File tree

1 file changed: +52 -17 lines changed


vllm_ascend/distributed/moe_comm_method.py

Lines changed: 52 additions & 17 deletions
@@ -42,28 +42,40 @@ def _pre_process(
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         """Pre-process before MLP.
+
         Args:
-            hidden_states: Tensor of shape (num_tokens, hidden_size)
-            topk_ids: Tensor of shape (num_tokens, top_k_num)
-            topk_weights: Tensor of shape (num_tokens, top_k_num)
-            expert_map: Tensor mapping global expert IDs to local IDs
-            num_experts: Number of local experts
+            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
+            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
+            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
+            expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
+                Mapping from global expert IDs to local expert IDs.
+            num_experts (int): Number of local experts (experts on this device).
+
         Returns:
-            permuted_hidden_states: Tensor of shape (num_tokens * top_k_num, hidden_size)
-            expert_tokens: Tensor of shape (num_experts,)
-            group_list_type: Argument for grouped matmul
+            tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
+                - permuted_hidden_states (torch.Tensor): Tensor of shape
+                    (num_tokens * top_k_num, hidden_size) after permuting
+                    hidden_states based on topk_ids.
+                - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
+                    Number of tokens assigned to each expert.
+                - group_list_type (int): Type of group list, 0 for `cumsum`
+                    and 1 for `count`. This is mainly for `npu_grouped_matmul`
+                    to determine how to handle the output.
+        Raises:
+            NotImplementedError: If the method is not implemented in the subclass.
         """
         pass
 
     @abstractmethod
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
         """Post-process after MLP.
+
         Args:
-            mlp_output: Tensor of shape (num_tokens * top_k_num, hidden_size)
-            hidden_states: Tensor of shape (num_tokens, hidden_size)
-        Returns:
-            None: This method mutates hidden_states in-place.
+            mlp_output (torch.Tensor): Tensor of shape
+                (num_tokens * top_k_num, hidden_size) after MLP.
+            hidden_states (torch.Tensor): Tensor of shape
+                (num_tokens, hidden_size) to be updated with the final output.
         """
         pass
 
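Aside (not part of the diff): the two `group_list_type` encodings mentioned in the new Returns section differ only in how the per-expert token tally is laid out before it is handed to the grouped matmul. A minimal sketch, with made-up numbers for four local experts:

import torch

# Hypothetical routing result: 3/0/5/2 tokens land on experts 0..3.
expert_tokens_count = torch.tensor([3, 0, 5, 2])             # group_list_type == 1 ("count")
expert_tokens_cumsum = torch.cumsum(expert_tokens_count, 0)  # tensor([3, 3, 8, 10]), group_list_type == 0 ("cumsum")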
@@ -78,6 +90,7 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        """Dummy implementation, see moe_comm_pre_process_fake for details."""
         return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                          expert_map, num_experts)
 
@@ -166,11 +179,22 @@ def _post_process(self, mlp_output: torch.Tensor,
 
 
 class AllGatherCommImpl(MoECommMethod):
-    """This implementation is for the scenarios listed below:
-    1. `enable_expert_parallel=False`.
-    2. If `npu_moe_init_routing_v2` is available, we will support `enable_expert_parallel=True`,
-    and this implementation will become the default one, changing the name to `AllGather` at
-    the same time.
+    """This implementation is the same as NativeAllGatherCommImpl,
+    but uses NPU-specific ops for better performance.
+
+    This implementation should be compatible with all scenarios, and
+    thus it is the default implementation for MoE communication methods.
+    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
+    and `torch_npu.npu_moe_token_unpermute` for post-processing
+    to handle the token-to-expert mapping and communication efficiently.
+
+    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
+    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
+    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
+    for pre-processing and post-processing, respectively.
+    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
+    use `torch_npu.npu_moe_token_unpermute` instead.
+    This is a workaround and should be removed after the issue is fixed.
     """
 
     def _pre_process(
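Aside (not part of the diff): as a rough mental model of the permute/unpermute pair this docstring refers to, a minimal native-PyTorch sketch follows. The function names and the stable-argsort grouping are illustrative assumptions; they are not the torch_npu kernels themselves.

import torch

def permute_sketch(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
    # Replicate each token once per selected expert, then sort by expert id
    # so that every expert's tokens are contiguous for the grouped matmul.
    top_k = topk_ids.shape[1]
    sort_idx = torch.argsort(topk_ids.flatten(), stable=True)
    permuted = hidden_states.repeat_interleave(top_k, dim=0)[sort_idx]
    return permuted, sort_idx

def unpermute_sketch(mlp_output: torch.Tensor, sort_idx: torch.Tensor,
                     topk_weights: torch.Tensor) -> torch.Tensor:
    # Undo the expert-major ordering, then combine each token's top_k expert
    # outputs using its routing weights.
    num_tokens, top_k = topk_weights.shape
    restored = torch.empty_like(mlp_output)
    restored[sort_idx] = mlp_output
    restored = restored.view(num_tokens, top_k, -1)
    return (restored * topk_weights.unsqueeze(-1)).sum(dim=1)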
@@ -392,6 +416,10 @@ def moe_comm_pre_process(
     expert_map: torch.Tensor,
     num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """This function is a wrapper for the pre_process method of the
+    MoECommMethod instance stored in the ForwardContext. So it can be
+    used as a custom op in the vllm framework.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.moe_comm_method
     return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
@@ -405,6 +433,9 @@ def moe_comm_pre_process_fake(
     expert_map: torch.Tensor,
     num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """This is a fake implementation of the pre_process method.
+    torch.compile will use this implementation to generate FX graph.
+    """
     top_k_num = topk_ids.shape[1]
     permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
     expert_tokens = torch.zeros((num_experts, ),
@@ -416,6 +447,10 @@ def moe_comm_pre_process_fake(
 
 def moe_comm_post_process(mlp_output: torch.Tensor,
                           hidden_states: torch.Tensor) -> None:
+    """This function is a wrapper for the post_process method of the
+    MoECommMethod instance stored in the ForwardContext. So it can be
+    used as a custom op in the vllm framework.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.moe_comm_method
     self._post_process(mlp_output, hidden_states)
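Aside (not part of the diff): the wrapper-plus-fake pattern documented in the last three hunks is the same idea as registering a PyTorch custom op with a shape-only "fake" implementation so torch.compile can trace it. A generic torch.library sketch is below; the op name `moe_demo::pre_process` and its simplified signature are made up for illustration and are not how vllm-ascend actually registers these ops.

import torch

@torch.library.custom_op("moe_demo::pre_process", mutates_args=())
def pre_process(hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    # Eager implementation: in the real code this would call the device kernels.
    top_k = topk_ids.shape[1]
    return hidden_states.repeat_interleave(top_k, dim=0)

@pre_process.register_fake
def _(hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used by torch.compile to build the FX
    # graph, analogous to moe_comm_pre_process_fake above.
    top_k = topk_ids.shape[1]
    return hidden_states.new_empty((hidden_states.shape[0] * top_k,
                                    hidden_states.shape[1]))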
