
Commit c9bc8f5

Support EP in All Gather mode
This commit refactors and cleans up the Mixture-of-Experts (MoE) communication implementations for Ascend NPUs. Key changes include:

- Renames `AllReduceCommImpl` to `AllGatherCommImpl` and updates its implementation to use `npu_moe_init_routing_v2` and `npu_moe_token_unpermute` for improved performance and correctness.
- Renames the original `AllGatherCommImpl` to `NativeAllGatherCommImpl` to clarify that it uses only native PyTorch operations.
- Removes `MC2CommImpl` from the communication-method selection in the model runner and sets `AllGatherCommImpl` as the default MoE communication method.
- Adds workarounds in both `AllGatherCommImpl` and `NativeAllGatherCommImpl` to handle incorrect outputs from `npu_grouped_matmul` by zeroing out the weights of invalid tokens (illustrated in the sketch below).
- Improves documentation by adding detailed docstrings to the abstract methods.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 1cd1293 commit c9bc8f5
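The workaround mentioned in the commit message amounts to masking the routing weights of tokens whose experts do not live on the current rank, so that any garbage rows produced by `npu_grouped_matmul` for those tokens are multiplied by zero before being combined. A minimal sketch with illustrative values (not taken verbatim from the diff):

    import torch

    # expert_map maps global expert IDs to local IDs; -1 marks experts not owned by this rank
    expert_map = torch.tensor([0, 1, -1, -1])              # this rank owns global experts 0 and 1
    topk_ids = torch.tensor([[0, 2], [3, 1]])              # (num_tokens, top_k_num)
    topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])  # (num_tokens, top_k_num)

    # Zero the weights of token-expert pairs routed to experts this rank does not own.
    # torch.where is used instead of boolean indexing so the op stays ACL-Graph friendly.
    mask = expert_map[topk_ids] != -1
    topk_weights = torch.where(mask, topk_weights, torch.zeros_like(topk_weights))
    # topk_weights is now [[0.7, 0.0], [0.0, 0.4]]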

File tree

3 files changed: +100 -101 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 
 import vllm_ascend.envs as envs
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
-from vllm_ascend.platform import NPUPlatform
 
 
 class FusedMoEState(Enum):

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 89 additions & 84 deletions
@@ -2,7 +2,7 @@
 
 import torch
 import torch_npu
-from vllm.distributed.parallel_state import get_tp_group
+from vllm.distributed.parallel_state import get_ep_group, get_tp_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 
@@ -34,13 +34,30 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        """Pre-process before MLP."""
+        """Pre-process before MLP.
+        Args:
+            hidden_states: Tensor of shape (num_tokens, hidden_size)
+            topk_ids: Tensor of shape (num_tokens, top_k_num)
+            topk_weights: Tensor of shape (num_tokens, top_k_num)
+            expert_map: Tensor mapping global expert IDs to local IDs
+            num_experts: Number of local experts
+        Returns:
+            permuted_hidden_states: Tensor of shape (num_tokens * top_k_num, hidden_size)
+            expert_tokens: Tensor of shape (num_experts,)
+            group_list_type: Argument for grouped matmul
+        """
         pass
 
     @abstractmethod
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        """Post-process after MLP."""
+        """Post-process after MLP.
+        Args:
+            mlp_output: Tensor of shape (num_tokens * top_k_num, hidden_size)
+            hidden_states: Tensor of shape (num_tokens, hidden_size)
+        Returns:
+            None: This method mutates hidden_states in-place.
+        """
         pass
 
 
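The new docstrings pin down the contract between `_pre_process` and `_post_process`. A rough, NPU-free illustration of that contract using only native PyTorch (a sketch for orientation, not code from this repo):

    import torch

    num_tokens, hidden_size, top_k_num, num_experts = 4, 8, 2, 3
    hidden_states = torch.randn(num_tokens, hidden_size)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k_num))
    topk_weights = torch.rand(num_tokens, top_k_num)

    # "_pre_process": group the (token, expert) pairs by expert ID
    flat_experts = topk_ids.reshape(-1)
    sort_idx = torch.argsort(flat_experts)
    token_idx = torch.arange(num_tokens).repeat_interleave(top_k_num)[sort_idx]
    permuted_hidden_states = hidden_states[token_idx]        # (num_tokens * top_k_num, hidden_size)
    expert_tokens = torch.bincount(flat_experts, minlength=num_experts)  # per-expert counts

    # stand-in for the grouped expert MLP (npu_grouped_matmul in the real implementation)
    mlp_output = permuted_hidden_states * 2.0

    # "_post_process": weight each expert output and scatter-add back, mutating hidden_states in-place
    sorted_weights = topk_weights.reshape(-1)[sort_idx].unsqueeze(1)
    out = torch.zeros_like(hidden_states)
    out.index_add_(0, token_idx, mlp_output * sorted_weights)
    hidden_states[:] = out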
@@ -63,9 +80,8 @@ def _post_process(self, mlp_output: torch.Tensor,
         pass
 
 
-class AllGatherCommImpl(MoECommMethod):
-    """This implementation is for the scenarios listed below:
-    1. `enable_expert_parallel=True`.
+class NativeAllGatherCommImpl(MoECommMethod):
+    """This implementation should be compatible with all scenarios.
 
     Note that this implementation purely consists of native PyTorch ops
     and does not use any NPU-specific ops. So the performance may not be optimal.
@@ -80,7 +96,6 @@ def _pre_process(
         expert_map: torch.Tensor,
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        print("Using AllGatherCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]
 
         # Generate token indices and flatten
@@ -98,6 +113,9 @@ def _pre_process(
 
         # Filter valid token-expert pairs
         mask = local_experts_flat != -1
+        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
+        # So we need to filter out invalid tokens by zeroing their weights.
+        # This is a workaround and should be removed after the issue is fixed
         filtered_weights = torch.where(mask, weights_flat,
                                        torch.zeros_like(weights_flat)).to(
                                            self.dtype)
@@ -106,7 +124,6 @@ def _pre_process(
             local_experts_flat,
             torch.full_like(local_experts_flat, num_experts),
         ).to(topk_ids.dtype)
-        self.mask = mask
 
         # Sort by local expert IDs
         sort_indices = torch.argsort(filtered_experts.view(torch.float32))
@@ -121,39 +138,27 @@ def _pre_process(
                                    dtype=torch.int64)
         ones = torch.ones_like(filtered_experts, dtype=torch.int64)
         token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
-        token_counts = token_counts[:num_experts]
-        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
+        expert_tokens = token_counts[:num_experts]
 
         # Rearrange hidden_states
         permuted_hidden_states = hidden_states[self.sorted_token_indices]
 
-        group_list_type = 0
+        group_list_type = 1  # `count` mode
 
         return permuted_hidden_states, expert_tokens, group_list_type
 
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        weighted_down_out = mlp_output * self.sorted_weights.unsqueeze(1)
+        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
 
         final_hidden_states = torch.zeros_like(hidden_states)
-
-        # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
-        # This created multiple NaN and index_add_ will mix them up which harms accuracy
-        # remove this mask and filter after it being fixed
-        num_valid_tokens = self.mask.sum()
-        valid_token_mask = (torch.arange(
-            0, self.sorted_token_indices.shape[0],
-            device=self.device).unsqueeze(1) < num_valid_tokens)
-        valid_output = torch.where(valid_token_mask, weighted_down_out,
-                                   torch.zeros_like(weighted_down_out)).to(
-                                       self.dtype)
         final_hidden_states.index_add_(0, self.sorted_token_indices,
-                                       valid_output)
+                                       mlp_output)
 
         hidden_states[:] = final_hidden_states
 
 
-class AllReduceCommImpl(MoECommMethod):
+class AllGatherCommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=False`.
     2. If `npu_moe_init_routing_v2` is available, we will support `enable_expert_parallel=True`,
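The `group_list_type` switch above changes only how per-expert token counts are handed to the grouped matmul: mode 0 expects a cumulative-sum group list, while mode 1 expects raw per-expert counts (the `count` mode used here). A tiny comparison with plain tensors, assuming those semantics as described in the in-code comments:

    import torch

    token_counts = torch.tensor([3, 0, 5, 2])              # tokens routed to each local expert

    cumsum_group_list = torch.cumsum(token_counts, dim=0)  # group_list_type = 0 -> tensor([3, 3, 8, 10])
    count_group_list = token_counts                        # group_list_type = 1 -> tensor([3, 0, 5, 2])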
@@ -169,70 +174,48 @@ def _pre_process(
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
-        print("Using AllReduceCommImpl for MoE communication.")
         num_tokens = hidden_states.shape[0]
 
         self.topk_weights = topk_weights
         self.topk_ids = topk_ids
 
-        # 1. Prepare row indices for routing
-        row_idx_len = num_tokens * self.top_k_num
-        row_idx = torch.arange(row_idx_len,
-                               dtype=torch.int32,
-                               device=self.device)
-        row_idx = row_idx.view(self.top_k_num, -1).permute(1, 0).contiguous()
-
-        # 2. Initial routing to expand tokens and experts
-        permuted_hidden_states, expanded_row_idx, expanded_expert_idx = (
-            torch_npu.npu_moe_init_routing(
+        first_expert_idx = 0
+        if expert_map is not None:
+            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
+            # So we need to filter out invalid tokens by zeroing their weights.
+            # This is a workaround and should be removed after the issue is fixed
+            mask = expert_map[topk_ids] != -1
+            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
+            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
+            self.topk_weights = torch.where(mask, topk_weights, 0.0)
+
+            first_expert_idx = get_ep_group().rank_in_group * num_experts
+        last_expert_idx = first_expert_idx + num_experts
+
+        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
+            torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
-                row_idx=row_idx,
-                expert_idx=topk_ids,
-                active_num=num_tokens,
+                topk_ids,
+                active_num=num_tokens * self.top_k_num,
+                expert_num=self.global_num_experts,
+                expert_tokens_num_type=1,  # Only support `count` mode now
+                expert_tokens_num_flag=True,  # Output `expert_tokens`
+                active_expert_range=[first_expert_idx, last_expert_idx],
+                quant_mode=-1,
             ))
-        # NOTE: Currently, V2 produces incorrect accuracy and weaker performance than V1
-        # first_expert_idx = 0
-        # if expert_map is not None:
-        #     first_expert_idx = torch.nonzero(expert_map != -1, as_tuple=False)[0].item()
-        # last_expert_idx = first_expert_idx + num_experts
-        # permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
-        #     torch_npu.npu_moe_init_routing_v2(
-        #         hidden_states,
-        #         topk_ids,
-        #         active_num=num_tokens * self.top_k_num,
-        #         expert_num=self.global_num_experts,
-        #         expert_tokens_num_type=1,  # Only support `count` mode now
-        #         expert_tokens_num_flag=True,  # Output `expert_tokens`
-        #         active_expert_range=[first_expert_idx, last_expert_idx],
-        #         quant_mode=-1,
-        #     )
-        # )
         self.expanded_row_idx = expanded_row_idx
         permuted_hidden_states = permuted_hidden_states
 
-        # 3. Compute expert tokens
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            expanded_expert_idx, num_experts).to(torch.int64)
-        # NOTE: This is also for npu_moe_init_routing_v2
-        # expert_tokens = torch.cumsum(expert_tokens, 0)
-
-        group_list_type = 0
+        group_list_type = 1  # `count` mode
 
         return permuted_hidden_states, expert_tokens, group_list_type
 
     def _post_process(self, mlp_output: torch.Tensor,
                       hidden_states: torch.Tensor) -> None:
-        hidden_states[:] = torch_npu.npu_moe_finalize_routing(
-            mlp_output,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=self.topk_weights,
-            expanded_src_to_dst_row=self.expanded_row_idx,
-            export_for_source_row=self.topk_ids,
-            # NOTE: For npu_moe_init_routing_v2
-            # drop_pad_mode=2,
-        )
+        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
+            permuted_tokens=mlp_output,
+            sorted_indices=self.expanded_row_idx,
+            probs=self.topk_weights)
 
 
 class MC2CommImpl(MoECommMethod):
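With expert parallelism enabled, each EP rank owns a contiguous slice of the global expert table, so the `active_expert_range` passed to `npu_moe_init_routing_v2` above follows directly from the rank. A worked example with illustrative numbers (not taken from any particular deployment):

    global_num_experts = 64
    ep_world_size = 4
    num_local_experts = global_num_experts // ep_world_size    # 16 experts per rank
    ep_rank = 2                                                 # value of get_ep_group().rank_in_group

    first_expert_idx = ep_rank * num_local_experts              # 32
    last_expert_idx = first_expert_idx + num_local_experts      # 48
    active_expert_range = [first_expert_idx, last_expert_idx]   # this rank handles global experts 32..47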
@@ -288,10 +271,27 @@ def _pre_process(
         num_experts: int,
     ) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Store tensors needed for post_process
-        self.topk_ids = topk_ids.clone()
-        self.topk_weights = topk_weights
+        self.topk_ids = topk_ids
+        self.topk_weights = topk_weights.to(torch.float32)
         self.mc2_mask = get_forward_context().mc2_mask
 
+        # tp_size = get_tensor_model_parallel_world_size()
+        # self.chunked_hidden_states = torch.tensor_split(hidden_states,
+        #                                                 tp_size,
+        #                                                 dim=0)
+        # chunked_topk_ids = torch.tensor_split(self.topk_ids,
+        #                                       tp_size,
+        #                                       dim=0)
+        # chunked_topk_weights = torch.tensor_split(self.topk_weights,
+        #                                           tp_size,
+        #                                           dim=0)
+        # chunked_mc2_mask = torch.tensor_split(self.mc2_mask, tp_size, dim=0)
+        # tp_rank = get_tensor_model_parallel_rank()
+        # hidden_states = self.chunked_hidden_states[tp_rank]
+        # self.topk_ids = chunked_topk_ids[tp_rank]
+        # self.topk_weights = chunked_topk_weights[tp_rank]
+        # self.mc2_mask = chunked_mc2_mask[tp_rank]
+
         dispatch_kwargs = {
             "x": hidden_states,
             "expert_ids": self.topk_ids,
@@ -326,7 +326,7 @@ def _pre_process(
             expert_tokens,
             self.ep_recv_counts,
             self.tp_recv_counts,
-        ) = torch_npu.npu_moe_distribute_dispatch_v2(**dispatch_kwargs)[:6]
+        ) = dispatch(**dispatch_kwargs)[:6]
 
         group_list_type = 1
 
@@ -337,7 +337,7 @@ def _post_process(self, mlp_output: torch.Tensor,
         combine_kwargs = {
             "expand_x": mlp_output,
             "expert_ids": self.topk_ids,
-            "expert_scales": self.topk_weights.to(torch.float32),
+            "expert_scales": self.topk_weights,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": self.global_num_experts,
@@ -366,12 +366,17 @@ def _post_process(self, mlp_output: torch.Tensor,
                 "x_active_mask": self.mc2_mask,
             })
 
-        if self.enable_dispatch_v2:
-            hidden_states[:] = torch_npu.npu_moe_distribute_combine_v2(
-                **combine_kwargs)
-        else:
-            hidden_states[:] = torch_npu.npu_moe_distribute_combine(
-                **combine_kwargs)
+        combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
+
+        hidden_states[:] = combine(**combine_kwargs)
+
+        # final_hidden_states = combine(**combine_kwargs)
+
+        # dist.all_gather(list(self.chunked_hidden_states), final_hidden_states, get_tp_group().device_group)
+
+        # final_hidden_states = torch.cat(self.chunked_hidden_states, dim=0)
+
+        # hidden_states[:] = final_hidden_states
 
 
 def moe_comm_pre_process(

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 16 deletions
@@ -26,7 +26,7 @@
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -80,9 +80,8 @@
 from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     AllReduceCommImpl,
+                                                     NativeAllGatherCommImpl,
                                                      DummyCommImpl,
-                                                     MC2CommImpl,
                                                      MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
@@ -379,6 +378,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
             self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
 
+        self.reserved_mc2_mask = torch.zeros(
+            512,
+            dtype=torch.bool,
+            device=self.device,
+        )
+
+        self.moe_comm_method = AllGatherCommImpl
+
     def check_batch_sizes_consistency(self) -> None:
         if not dist.is_initialized():
             return
@@ -401,18 +408,6 @@ def check_batch_sizes_consistency(self) -> None:
                 f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
             )
 
-        self.reserved_mc2_mask = torch.zeros(
-            512,
-            dtype=torch.bool,
-            device=self.device,
-        )
-
-        if self.parallel_config.enable_expert_parallel:
-            # self.moe_comm_method = AllGatherCommImpl
-            self.moe_comm_method = MC2CommImpl
-        else:
-            self.moe_comm_method = AllReduceCommImpl
-
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -1880,7 +1875,7 @@ def _dummy_run(
         skip_attn: bool = True,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
-        moe_comm_method: MoECommMethod = DummyCommImpl,
+        moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
     ) -> torch.Tensor:
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
