
Commit c9c27ad

Refactor MoE communication logic with a strategy pattern
Introduces a `MoECommMethod` abstract base class to encapsulate different communication strategies for Mixture of Experts layers. This change decouples the MoE implementation from the specific communication method. Two initial strategies are provided:

- `AllGatherCommImpl`: a pure PyTorch implementation for expert parallel scenarios.
- `AllReduceCommImpl`: utilizes NPU-specific ops for non-expert parallel cases.

The selection of the communication method is now determined at runtime based on the parallel configuration. This improves code organization and makes it easier to add or swap communication strategies in the future.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent f939381 commit c9c27ad

File tree: 5 files changed, +408 −18 lines


vllm_ascend/ascend_forward_context.py

Lines changed: 10 additions & 6 deletions

@@ -7,9 +7,9 @@
 from vllm.config import VllmConfig
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context, set_forward_context
-from vllm.platforms import current_platform
 
 import vllm_ascend.envs as envs
+from vllm_ascend.distributed.moe_comm_method import MoECommMethod
 
 
 class FusedMoEState(Enum):
@@ -54,6 +54,8 @@ def set_ascend_forward_context(
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
         in_profile_run: bool = False,
+        reserved_mc2_mask: Optional[torch.Tensor] = None,
+        moe_comm_method: Optional[MoECommMethod] = None,
         num_actual_tokens: Optional[int] = None,
 ):
     """A context manager that stores the current forward context,
@@ -66,6 +68,7 @@ def set_ascend_forward_context(
             num_tokens=num_tokens,
             num_tokens_across_dp=num_tokens_across_dp):
         forward_context = get_forward_context()
+        forward_context.moe_comm_method = moe_comm_method
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -110,11 +113,12 @@ def set_ascend_forward_context(
             forward_context.padded_num_tokens = math.ceil(
                 max_tokens_across_dp / tp_world_size) * tp_world_size
 
-            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
-                                   dtype=torch.bool,
-                                   device=current_platform.device_type)
-            mc2_mask[:num_actual_tokens] = True
-            forward_context.mc2_mask = mc2_mask
+            if reserved_mc2_mask is not None:
+                mc2_mask = reserved_mc2_mask[:forward_context.
+                                             padded_num_tokens]
+                mc2_mask[:num_actual_tokens] = True
+                mc2_mask[num_actual_tokens:] = False
+                forward_context.mc2_mask = mc2_mask
 
         try:
             yield
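
The commit message says the communication method is selected at runtime from the parallel configuration, but the selection site itself is not part of this diff. A minimal sketch of what that caller-side choice could look like; the helper name and surrounding variables are assumptions, only the EP-to-strategy mapping follows the commit message:

    # Hypothetical caller-side selection (not part of this diff).
    from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                          AllReduceCommImpl)

    def select_moe_comm_method(vllm_config, device, dtype, top_k_num,
                               global_num_experts):
        impl_cls = (AllGatherCommImpl
                    if vllm_config.parallel_config.enable_expert_parallel else
                    AllReduceCommImpl)
        return impl_cls(device=device,
                        dtype=dtype,
                        top_k_num=top_k_num,
                        global_num_experts=global_num_experts)

    # The instance is then handed to the context manager added above:
    # with set_ascend_forward_context(..., moe_comm_method=moe_comm_method):
    #     ...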
vllm_ascend/distributed/moe_comm_method.py

Lines changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@
from abc import ABC, abstractmethod

import torch
import torch_npu
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op


class MoECommMethod(ABC):
    """Base class for MoE communication methods."""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        top_k_num: int,
        global_num_experts: int,
    ):
        self.device = device
        self.dtype = dtype
        self.top_k_num = top_k_num
        self.global_num_experts = global_num_experts

    @abstractmethod
    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Pre-process before MLP."""
        pass

    @abstractmethod
    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Post-process after MLP."""
        pass
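
The contract between the two hooks is that `_pre_process` permutes tokens into expert-major order and returns `(permuted_hidden_states, expert_tokens, group_list_type)` for the grouped MLP, while `_post_process` un-permutes the MLP output and writes the combined result back into `hidden_states` in place. A rough sketch of how a layer drives a strategy instance; `grouped_mlp` is a placeholder for the real grouped-matmul kernel and is not defined in this commit:

    # Sketch only: the real call path goes through the custom ops registered at
    # the bottom of this file, but the data flow is the same.
    def run_moe_layer(comm_method, hidden_states, topk_ids, topk_weights,
                      expert_map, num_experts, grouped_mlp):
        permuted, expert_tokens, group_list_type = comm_method._pre_process(
            hidden_states, topk_ids, topk_weights, expert_map, num_experts)
        mlp_output = grouped_mlp(permuted, expert_tokens, group_list_type)
        comm_method._post_process(mlp_output, hidden_states)  # in-place write-back
        return hidden_states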
class DummyCommImpl(MoECommMethod):

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                         expert_map, num_experts)

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Dummy implementation that does nothing."""
        pass
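
`DummyCommImpl` simply reuses `moe_comm_pre_process_fake` (defined later in this file), which only needs to produce outputs of the right shape, `num_tokens * top_k` permuted rows and a zeroed per-expert count, so that fake/meta tracing of the custom op works. A shape-only illustration with made-up sizes:

    # Values are irrelevant here; only the output shapes matter for the fake path.
    hidden_states = torch.randn(4, 8)                 # 4 tokens, hidden size 8
    topk_ids = torch.zeros(4, 2, dtype=torch.int32)   # top_k = 2
    permuted, expert_tokens, _ = moe_comm_pre_process_fake(
        hidden_states, topk_ids, topk_ids.float(), None, num_experts=16)
    assert permuted.shape == (8, 8)        # num_tokens * top_k rows
    assert expert_tokens.shape == (16, )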
class AllGatherCommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.

    Note that this implementation consists purely of native PyTorch ops and
    does not use any NPU-specific ops, so performance may not be optimal.
    It is, however, a good fallback when NPU-specific ops are not available.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        num_tokens = hidden_states.shape[0]

        # Generate token indices and flatten
        token_indices = torch.arange(num_tokens,
                                     device=self.device,
                                     dtype=torch.int64)
        token_indices = (token_indices.unsqueeze(1).expand(
            -1, self.top_k_num).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = (expert_map[experts_flat]
                              if expert_map is not None else experts_flat)

        # Filter valid token-expert pairs
        mask = local_experts_flat != -1
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
                                           self.dtype)
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
            torch.full_like(local_experts_flat, num_experts),
        ).to(topk_ids.dtype)
        self.mask = mask

        # Sort by local expert IDs
        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
        self.sorted_token_indices = token_indices[sort_indices]
        self.sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=self.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        token_counts = token_counts[:num_experts]
        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)

        # Rearrange hidden_states
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 0

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        weighted_down_out = mlp_output * self.sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros_like(hidden_states)

        # TODO: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...].
        # This produces NaNs that index_add_ would mix in, harming accuracy.
        # Remove this mask and filter once that is fixed.
        num_valid_tokens = self.mask.sum()
        valid_token_mask = (torch.arange(
            0, self.sorted_token_indices.shape[0],
            device=self.device).unsqueeze(1) < num_valid_tokens)
        valid_output = torch.where(valid_token_mask, weighted_down_out,
                                   torch.zeros_like(weighted_down_out)).to(
                                       self.dtype)
        final_hidden_states.index_add_(0, self.sorted_token_indices,
                                       valid_output)

        hidden_states[:] = final_hidden_states
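
The pre-process above is, at its core, a sort-by-local-expert permutation: expand each token `top_k` times, map global expert IDs to local ones, push non-local pairs (mapped to -1) behind a sentinel ID equal to `num_experts`, then sort and count. A small CPU-only illustration of the same index manipulation, with made-up routing values:

    # 3 tokens, top-2 routing over 4 global experts; this rank owns experts 2 and 3.
    topk_ids = torch.tensor([[0, 2], [3, 1], [2, 3]])
    expert_map = torch.tensor([-1, -1, 0, 1])      # global id -> local id, -1 = remote
    num_local_experts = 2

    token_indices = torch.arange(3).repeat_interleave(2)   # [0, 0, 1, 1, 2, 2]
    local = expert_map[topk_ids.view(-1)]                   # [-1, 0, 1, -1, 0, 1]
    mask = local != -1
    filtered = torch.where(mask, local,
                           torch.full_like(local, num_local_experts))
    order = torch.argsort(filtered)            # local experts first, remote pairs last
    sorted_tokens = token_indices[order]       # rows to gather from hidden_states
    counts = torch.bincount(filtered,
                            minlength=num_local_experts + 1)[:num_local_experts]
    # counts == tensor([2, 2]): two routed token copies per local expert.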
class AllReduceCommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=False`.
    2. Once `npu_moe_init_routing_v2` is available, this implementation will also
       support `enable_expert_parallel=True`, become the default, and be renamed
       to `AllGather` at the same time.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,  # noqa: F841
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        num_tokens = hidden_states.shape[0]

        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        # 1. Prepare row indices for routing
        row_idx_len = num_tokens * self.top_k_num
        row_idx = torch.arange(row_idx_len,
                               dtype=torch.int32,
                               device=self.device)
        row_idx = row_idx.view(self.top_k_num, -1).permute(1, 0).contiguous()

        # 2. Initial routing to expand tokens and experts
        permuted_hidden_states, expanded_row_idx, expanded_expert_idx = (
            torch_npu.npu_moe_init_routing(
                hidden_states,
                row_idx=row_idx,
                expert_idx=topk_ids,
                active_num=num_tokens,
            ))
        # NOTE: Currently, V2 produces incorrect accuracy and weaker performance than V1
        # first_expert_idx = 0
        # if expert_map is not None:
        #     first_expert_idx = torch.nonzero(expert_map != -1, as_tuple=False)[0].item()
        # last_expert_idx = first_expert_idx + num_experts
        # permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
        #     torch_npu.npu_moe_init_routing_v2(
        #         hidden_states,
        #         topk_ids,
        #         active_num=num_tokens * self.top_k_num,
        #         expert_num=self.global_num_experts,
        #         expert_tokens_num_type=1,  # Only support `count` mode now
        #         expert_tokens_num_flag=True,  # Output `expert_tokens`
        #         active_expert_range=[first_expert_idx, last_expert_idx],
        #         quant_mode=-1,
        #     )
        # )
        self.expanded_row_idx = expanded_row_idx
        permuted_hidden_states = permuted_hidden_states

        # 3. Compute expert tokens
        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts).to(torch.int64)
        # NOTE: This is also for npu_moe_init_routing_v2
        # expert_tokens = torch.cumsum(expert_tokens, 0)

        group_list_type = 0

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        hidden_states[:] = torch_npu.npu_moe_finalize_routing(
            mlp_output,
            skip1=None,
            skip2=None,
            bias=None,
            scales=self.topk_weights,
            expanded_src_to_dst_row=self.expanded_row_idx,
            export_for_source_row=self.topk_ids,
            # NOTE: For npu_moe_init_routing_v2
            # drop_pad_mode=2,
        )
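
The `row_idx` tensor built in step 1 is just the flat `[0, num_tokens * top_k)` range laid out so that each token's `top_k` slots are strided by `num_tokens`, which is the layout this code hands to `npu_moe_init_routing`. For example, with plain PyTorch and illustrative sizes:

    num_tokens, top_k = 3, 2
    row_idx = torch.arange(num_tokens * top_k, dtype=torch.int32)
    row_idx = row_idx.view(top_k, -1).permute(1, 0).contiguous()
    # tensor([[0, 3],
    #         [1, 4],
    #         [2, 5]], dtype=torch.int32)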
def moe_comm_pre_process(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.moe_comm_method
    return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
                             num_experts)


def moe_comm_pre_process_fake(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    top_k_num = topk_ids.shape[1]
    permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
    expert_tokens = torch.zeros((num_experts, ),
                                dtype=torch.int64,
                                device=hidden_states.device)
    group_list_type = 0
    return permuted_hidden_states, expert_tokens, group_list_type


def moe_comm_post_process(mlp_output: torch.Tensor,
                          hidden_states: torch.Tensor) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.moe_comm_method
    self._post_process(mlp_output, hidden_states)
    return


direct_register_custom_op(
    op_name="moe_comm_pre_process",
    op_func=moe_comm_pre_process,
    mutates_args=[],
    fake_impl=moe_comm_pre_process_fake,
    dispatch_key="PrivateUse1",
)

direct_register_custom_op(
    op_name="moe_comm_post_process",
    op_func=moe_comm_post_process,
    mutates_args=["hidden_states"],
    fake_impl=lambda x, y: None,  # No-op for fake implementation
    dispatch_key="PrivateUse1",
)
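
`direct_register_custom_op` wraps the two free functions as custom ops so the MoE path stays graph-capturable; the strategy instance is looked up from the forward context at call time instead of being baked into the graph. Assuming vLLM's default `vllm` op namespace (an assumption; the registration above does not pass an explicit library), a call site such as `unified_fused_experts` would invoke them roughly like this:

    # Sketch of the call site (not part of this diff); the grouped matmul name
    # is a placeholder.
    permuted, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
        hidden_states, topk_ids, topk_weights, expert_map, global_num_experts)
    mlp_output = grouped_matmul(permuted, expert_tokens, group_list_type)  # placeholder
    torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)  # in-place write-back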

vllm_ascend/ops/common_fused_moe.py

Lines changed: 8 additions & 9 deletions

@@ -19,11 +19,12 @@
 
 import torch
 from vllm.config import CompilationLevel, get_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
-from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
-                                       select_experts)
+from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
+                                       unified_fused_experts)
 from vllm_ascend.utils import is_310p
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -86,20 +87,18 @@ def forward_oot(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
-    # If use aclgraph, we need to set max_num_tokens to make
-    # the input shape of `npu_moe_init_routing` fixed
-    max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
+    moe_comm_method = get_forward_context().moe_comm_method
 
-    return fused_experts(
+    return unified_fused_experts(
         hidden_states=x,
         w1=layer.w13_weight,
         w2=layer.w2_weight,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        top_k=top_k,
+        global_num_experts=global_num_experts,
         expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        max_num_tokens=max_num_tokens)
+        moe_comm_method=moe_comm_method,
+    )
 
 
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
