Commit 74f8b45

Author: angazenn
support MERRouter
Signed-off-by: angazenn <zengyanjia@huawei.com>
1 parent 5571fb7

File tree

vllm_ascend/models/pangu_moe.py
vllm_ascend/ops/fused_moe.py

2 files changed (+73 -33 lines)

vllm_ascend/models/pangu_moe.py

Lines changed: 66 additions & 33 deletions
@@ -49,7 +49,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import is_310p
 
 logger = init_logger(__name__)

@@ -95,41 +95,69 @@ def forward(self, x):
         return x
 
 
-class PanguProMoESparseMoeBlock(nn.Module):
+def topk_wrapper(num_voted_experts):
 
-    @staticmethod
     def pangu_group8_topk(
         hidden_states: torch.Tensor,
         gating_output: torch.Tensor,
         topk: int,
-        renormalize: bool,
+        renormalize: bool = False,
         num_expert_group: int = 0,
         topk_group: int = 0,
         global_num_experts: int = 0,
     ):
-        ep_size = get_ep_group().world_size
-        local_num_experts = global_num_experts // ep_size
-        local_num_group = topk // ep_size
-        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
-        scores = F.softmax(gating_output, dim=1)
-        scores = scores[...,
-                        get_ep_group().rank_in_group *
-                        local_num_experts:(get_ep_group().rank_in_group + 1) *
-                        local_num_experts]
-
-        router_weights = router_scale[get_ep_group().rank_in_group *
-                                      local_num_experts:
-                                      (get_ep_group().rank_in_group + 1) *
-                                      local_num_experts]
-        topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
-                                                       local_num_group, -1),
-                                           dim=-1)
-        bias = torch.arange(0,
-                            local_num_experts,
-                            topk,
-                            device=scores.device,
-                            dtype=torch.int32).unsqueeze(0)
-        topk_ids = topk_ids.to(torch.int32) + bias
+        scores = F.softmax(gating_output, dim=1, dtype=torch.float16)
+        num_tokens = scores.shape[0]
+        router_weights = _ROUTER_SCALE.squeeze(  # type: ignore
+        ).to(torch.float16)
+
+        if num_voted_experts == 8:
+            # use original topk
+            topk_weights, topk_ids = torch.max(scores.view(
+                scores.shape[0], topk, -1),
+                                               dim=-1)
+            bias = torch.arange(0,
+                                global_num_experts,
+                                topk,
+                                device=scores.device,
+                                dtype=torch.int32).unsqueeze(0)
+            topk_ids = topk_ids.to(torch.int32) + bias
+
+        else:
+            experts_per_group = global_num_experts // topk
+            group_expert_indices = torch.arange(experts_per_group,
+                                                dtype=torch.int32,
+                                                device=scores.device).view(
+                                                    1, 1, -1)
+            group_expert_offset = (
+                torch.arange(topk, dtype=torch.int32, device=scores.device) *
+                experts_per_group).unsqueeze(0)
+            expert_index_range = torch.arange(experts_per_group,
+                                              dtype=torch.int32,
+                                              device=scores.device)
+
+            scores_grouped = scores.view(num_tokens, topk, experts_per_group)
+            best_expert_idx = torch.argmax(scores_grouped,
+                                           dim=2)  # (num_tokens, num_groups)
+            vote_mask = (best_expert_idx.unsqueeze(-1).to(
+                torch.int32) == group_expert_indices).to(torch.float16)
+
+            expert_vote_freq = vote_mask.sum(dim=0)
+
+            sorted_indices = torch.argsort(expert_vote_freq,
+                                           dim=1,
+                                           descending=True).to(torch.int32)
+            topk_experts = sorted_indices[:, :num_voted_experts]
+            keep_mask = ((
+                topk_experts.unsqueeze(-1) == expert_index_range).any(
+                    dim=1)).unsqueeze(0)
+
+            masked_scores = torch.where(keep_mask, scores_grouped, 0)
+
+            topk_weights, best_pos_in_group = masked_scores.max(dim=2)
+            best_pos_in_group = best_pos_in_group.to(torch.int32)
+            topk_ids = (best_pos_in_group + group_expert_offset).to(
+                torch.int32)
 
         flatten_topk_ids = topk_ids.view(-1)
         router_weights = router_weights.index_select(0, flatten_topk_ids).view(
@@ -138,6 +166,11 @@ def pangu_group8_topk(
 
         return topk_weights, topk_ids
 
+    return pangu_group8_topk
+
+
+class PanguProMoESparseMoeBlock(nn.Module):
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -153,23 +186,23 @@ def __init__(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.local_num_group = config.num_experts_per_tok // get_ep_group(
-        ).world_size
         self.num_experts_per_tok = config.num_experts_per_tok
-        self.local_num_experts = config.num_experts // get_ep_group(
-        ).world_size
         self.router_scale = torch.nn.Parameter(
             torch.ones((1, self.num_experts)))
 
+        # On the 300I Duo platform, we find that setting num_voted_experts to 5
+        # achieves good performance without sacrificing too much accuracy; on
+        # other platforms it is set to 8 to use the original pangu grouped topk.
+        num_voted_experts = 5 if is_310p() else 8
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             quant_config=quant_config,
-            custom_routing_function=PanguProMoESparseMoeBlock.
-            pangu_group8_topk,
+            custom_routing_function=topk_wrapper(num_voted_experts),
             prefix=f"{prefix}.experts",
         )
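
Note: the else branch above is the core of the new router. Rather than each token independently taking its best expert per group, the batch first votes on which experts in each group stay active, and tokens are then restricted to the survivors. Below is a minimal standalone sketch of that routine; the sizes (4 tokens, 8 groups of 8 experts, 5 voted experts) are illustrative assumptions rather than values read from the model config, and one_hot/scatter_ stand in for the comparison-based masks used in the diff.

import torch

# Toy sizes (assumptions for illustration only).
num_tokens, topk, experts_per_group = 4, 8, 8
num_voted_experts = 5
global_num_experts = topk * experts_per_group

scores = torch.rand(num_tokens, global_num_experts).softmax(dim=1)
scores_grouped = scores.view(num_tokens, topk, experts_per_group)

# 1. Every token votes for its highest-scoring expert inside each group.
best_expert_idx = scores_grouped.argmax(dim=2)          # (num_tokens, topk)

# 2. Count the votes each expert received across the batch.
vote_mask = torch.nn.functional.one_hot(best_expert_idx, experts_per_group)
expert_vote_freq = vote_mask.sum(dim=0)                 # (topk, experts_per_group)

# 3. Keep only the num_voted_experts most-voted experts per group.
top_voted = expert_vote_freq.argsort(dim=1, descending=True)[:, :num_voted_experts]
keep_mask = torch.zeros(topk, experts_per_group, dtype=torch.bool)
keep_mask.scatter_(1, top_voted, True)

# 4. Each token routes to its best surviving expert in every group.
masked_scores = torch.where(keep_mask.unsqueeze(0), scores_grouped, 0.0)
topk_weights, best_pos_in_group = masked_scores.max(dim=2)
group_expert_offset = (torch.arange(topk) * experts_per_group).unsqueeze(0)
topk_ids = best_pos_in_group + group_expert_offset      # (num_tokens, topk)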

vllm_ascend/ops/fused_moe.py

Lines changed: 7 additions & 0 deletions
@@ -578,6 +578,13 @@ def fused_experts_310p(
     local_num_experts = global_num_experts // ep_size
     local_num_group = top_k // ep_size
 
+    if ep_size > 1:
+        ep_rank = get_ep_group().rank_in_group
+        local_group_start = ep_rank * local_num_experts
+        local_group_end = (ep_rank + 1) * local_num_experts
+        topk_ids = topk_ids[:, local_group_start:local_group_end]
+        topk_weights = topk_weights[:, local_group_start:local_group_end]
+
     if apply_router_weight_on_input:
         assert (topk_weights.dim() == 2
                 ), "`topk_weights` should be in shape (num_tokens, topk)"
