Commit 3a459de
feat: Add MC2 communication method for MoE
Introduces and enables the MC2 communication implementation for Mixture-of-Experts (MoE) on Ascend devices when expert parallelism is active. The new method leverages the platform-specific `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` operators to increase communication and computation parallelism, improving performance. The implementation also adapts to the Ascend SoC version and to which operator variants are available.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent c9c27ad commit 3a459de

File tree

3 files changed (+156, -5 lines)

vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 2 deletions

@@ -5,7 +5,8 @@

 import torch
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context, set_forward_context

 import vllm_ascend.envs as envs
@@ -108,7 +109,7 @@ def set_ascend_forward_context(
         forward_context.max_tokens_across_dp = max_tokens_across_dp

         if num_tokens is not None:
-            tp_world_size = get_tp_group().world_size
+            tp_world_size = get_tensor_model_parallel_world_size()
             # NOTE: token num which need to pad to when mc2
             forward_context.padded_num_tokens = math.ceil(
                 max_tokens_across_dp / tp_world_size) * tp_world_size
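
The padding computed here rounds the cross-DP token count up to the nearest multiple of the tensor-parallel world size, which is the shape the MC2 path expects. A minimal standalone sketch of that arithmetic (illustrative only, not part of the commit):

    import math

    def mc2_padded_num_tokens(max_tokens_across_dp: int, tp_world_size: int) -> int:
        # Round the token count up to a multiple of the TP world size, matching
        # the formula used in set_ascend_forward_context above.
        return math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size

    assert mc2_padded_num_tokens(9, 4) == 12  # 9 tokens, TP=4 -> pad to 12
    assert mc2_padded_num_tokens(8, 4) == 8   # already a multiple, no padding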

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 145 additions & 0 deletions

@@ -2,9 +2,13 @@

 import torch
 import torch_npu
+from vllm.distributed.parallel_state import get_tp_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op

+from vllm_ascend.distributed.parallel_state import get_mc2_group
+from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
+

 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
@@ -76,6 +80,7 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        print("Using AllGatherCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]

         # Generate token indices and flatten
@@ -164,6 +169,7 @@ def _pre_process(
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        print("Using AllReduceCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]

         self.topk_weights = topk_weights
@@ -229,6 +235,145 @@ def _post_process(self, mlp_output: torch.Tensor,
         )


+class MC2CommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True`.
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
+    3. `enable_expert_parallel=False` is not supported.
+
+    This implementation uses the MC2 communication method, which is optimized for
+    communication and computation parallelism on Ascend devices.
+    """
+
+    def __init__(
+        self,
+        device: torch.device,
+        dtype: torch.dtype,
+        top_k_num: int,
+        global_num_experts: int,
+    ):
+        super().__init__(device, dtype, top_k_num, global_num_experts)
+
+        # Shared communication configurations
+        ep_group = get_mc2_group()
+        self.ep_rank_id = ep_group.rank_in_group
+        self.ep_world_size = ep_group.world_size
+        self.tp_world_size = get_tp_group().world_size
+
+        device_group = ep_group.device_group
+        local_rank = torch.distributed.get_rank(group=device_group)
+        backend = device_group._get_backend(torch.device("npu"))
+        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
+
+        # Feature flags
+        self.enable_dispatch_v2 = hasattr(torch_npu,
+                                          "npu_moe_distribute_dispatch_v2")
+        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
+        self.need_extra_args = self.is_ascend_a3  # or is_torchair
+
+        # Intermediate tensors to be passed from pre_process to post_process
+        self.topk_ids = None
+        self.topk_weights = None
+        self.mc2_mask = None
+        self.assist_info_for_combine = None
+        self.ep_recv_counts = None
+        self.tp_recv_counts = None
+
+    def _pre_process(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        expert_map: torch.Tensor,
+        num_experts: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        # Store tensors needed for post_process
+        self.topk_ids = topk_ids.clone()
+        self.topk_weights = topk_weights
+        self.mc2_mask = get_forward_context().mc2_mask
+
+        dispatch_kwargs = {
+            "x": hidden_states,
+            "expert_ids": self.topk_ids,
+            "expert_shard_type": 0,
+            "shared_expert_rank_num": 0,
+            "moe_expert_num": self.global_num_experts,
+            "global_bs": 0,
+            "scales": None,
+            "quant_mode": 0,
+            "group_ep": self.moe_all_to_all_group_name,
+            "ep_world_size": self.ep_world_size,
+            "ep_rank_id": self.ep_rank_id,
+        }
+
+        if self.need_extra_args:
+            dispatch_kwargs.update({
+                "group_tp": self.moe_all_to_all_group_name,
+                "tp_world_size": 1,
+                "tp_rank_id": 0,
+            })
+        if self.is_ascend_a3 and self.enable_dispatch_v2:
+            dispatch_kwargs.update({
+                "x_active_mask": self.mc2_mask,
+            })
+
+        dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
+
+        (
+            permuted_hidden_states,
+            _,  # dynamic_scale is not used
+            self.assist_info_for_combine,
+            expert_tokens,
+            self.ep_recv_counts,
+            self.tp_recv_counts,
+        ) = dispatch(**dispatch_kwargs)[:6]
+
+        group_list_type = 1
+
+        return permuted_hidden_states, expert_tokens, group_list_type
+
+    def _post_process(self, mlp_output: torch.Tensor,
+                      hidden_states: torch.Tensor) -> None:
+        combine_kwargs = {
+            "expand_x": mlp_output,
+            "expert_ids": self.topk_ids,
+            "expert_scales": self.topk_weights.to(torch.float32),
+            "expert_shard_type": 0,
+            "shared_expert_rank_num": 0,
+            "moe_expert_num": self.global_num_experts,
+            "global_bs": 0,
+            "ep_send_counts": self.ep_recv_counts,
+            "group_ep": self.moe_all_to_all_group_name,
+            "ep_world_size": self.ep_world_size,
+            "ep_rank_id": self.ep_rank_id,
+        }
+
+        if self.enable_dispatch_v2:
+            combine_kwargs[
+                "assist_info_for_combine"] = self.assist_info_for_combine
+        else:
+            combine_kwargs["expand_idx"] = self.assist_info_for_combine
+
+        if self.need_extra_args:
+            combine_kwargs.update({
+                "tp_send_counts": self.tp_recv_counts,
+                "group_tp": self.moe_all_to_all_group_name,
+                "tp_world_size": 1,
+                "tp_rank_id": 0,
+            })
+        if self.is_ascend_a3 and self.enable_dispatch_v2:
+            combine_kwargs.update({
+                "x_active_mask": self.mc2_mask,
+            })
+
+        if self.enable_dispatch_v2:
+            hidden_states[:] = torch_npu.npu_moe_distribute_combine_v2(
+                **combine_kwargs)
+        else:
+            hidden_states[:] = torch_npu.npu_moe_distribute_combine(
+                **combine_kwargs)
+
+
 def moe_comm_pre_process(
     hidden_states: torch.Tensor,
     topk_ids: torch.Tensor,
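
For intuition about what the dispatch/combine pair is doing, here is a CPU-only toy analogue in plain PyTorch. It is a sketch under strong simplifications: the helper names are invented, there is no expert-parallel exchange, no quantization, and none of the HCCL group arguments (`group_ep`, `ep_world_size`, `ep_rank_id`) that the real `npu_moe_distribute_dispatch`/`npu_moe_distribute_combine` operators require.

    import torch

    def toy_dispatch(hidden_states, topk_ids, num_experts):
        # Sort (token, expert) pairs by expert id so each expert's tokens are
        # contiguous, roughly what the dispatch step produces locally.
        num_tokens, top_k = topk_ids.shape
        flat_expert_ids = topk_ids.reshape(-1)
        order = torch.argsort(flat_expert_ids, stable=True)
        token_index = torch.arange(num_tokens).repeat_interleave(top_k)[order]
        permuted = hidden_states[token_index]
        tokens_per_expert = torch.bincount(flat_expert_ids, minlength=num_experts)
        return permuted, tokens_per_expert, order, token_index

    def toy_combine(expert_output, topk_weights, order, token_index, num_tokens):
        # Weight each expert output by its routing weight and scatter-add the
        # rows back to their original token positions (the combine step).
        weights = topk_weights.reshape(-1)[order].unsqueeze(-1)
        out = torch.zeros(num_tokens, expert_output.shape[-1],
                          dtype=expert_output.dtype)
        out.index_add_(0, token_index, expert_output * weights)
        return out

    hidden = torch.randn(4, 8)
    topk_ids = torch.tensor([[0, 1], [1, 2], [0, 3], [2, 3]])
    topk_weights = torch.full((4, 2), 0.5)
    permuted, counts, order, token_index = toy_dispatch(hidden, topk_ids, 4)
    mlp_out = permuted  # identity "expert MLP" just for the round-trip check
    restored = toy_combine(mlp_out, topk_weights, order, token_index, 4)
    assert torch.allclose(restored, hidden, atol=1e-6)

The real operators additionally exchange the permuted tokens between expert-parallel ranks, which is what enables the communication/computation overlap the class docstring refers to.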

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 3 deletions

@@ -82,6 +82,7 @@
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                       AllReduceCommImpl,
                                                       DummyCommImpl,
+                                                      MC2CommImpl,
                                                       MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
@@ -365,7 +366,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         )

         if self.parallel_config.enable_expert_parallel:
-            self.moe_comm_method = AllGatherCommImpl
+            # self.moe_comm_method = AllGatherCommImpl
+            self.moe_comm_method = MC2CommImpl
         else:
             self.moe_comm_method = AllReduceCommImpl

@@ -1218,12 +1220,15 @@ def _process_reqs(

         moe_comm_method = self.moe_comm_method

+        # NOTE: Currently this padding logic is really messy,
+        # MC2 may not be available in eager mode
+        if not self.use_aclgraph or self.torchair_graph_enabled:
+            num_input_tokens = padded_num_tokens_across_dp
+
         # Run forward pass
         with set_ascend_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                # NOTE: This will break some function
-                # num_tokens=padded_num_tokens_across_dp,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=with_prefill,
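
Taken together, the runner change selects `MC2CommImpl` whenever expert parallelism is enabled and switches to the padded token count when not capturing with ACL graph (or when the torchair graph is enabled). A condensed sketch of that control flow, using hypothetical free functions rather than the runner's attributes:

    def select_moe_comm_method(enable_expert_parallel: bool) -> str:
        # Mirrors the __init__ change: MC2 is used only with expert parallelism.
        return "MC2CommImpl" if enable_expert_parallel else "AllReduceCommImpl"

    def resolve_num_input_tokens(num_input_tokens: int,
                                 padded_num_tokens_across_dp: int,
                                 use_aclgraph: bool,
                                 torchair_graph_enabled: bool) -> int:
        # Mirrors the _process_reqs change: outside ACL graph mode (or with the
        # torchair graph enabled) MC2 runs on the padded token count.
        if not use_aclgraph or torchair_graph_enabled:
            return padded_num_tokens_across_dp
        return num_input_tokens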
