
Commit 41e47d8

Author: angazenn
Commit message: support MERRouter
Signed-off-by: angazenn <zengyanjia@huawei.com>

1 parent 0060886

2 files changed: +61, -26 lines

vllm_ascend/ascend_config.py

Lines changed: 21 additions & 0 deletions

@@ -72,6 +72,27 @@ def __init__(self, torchair_graph_config):
         )
 
 
+class AscendModelConfig:
+    """
+    Configuration Object for ascend_model_config from additional_config
+    """
+
+    def __init__(self, ascend_model_config: dict):
+        self.num_voted_experts = ascend_model_config.get(
+            "num_voted_experts", None)
+
+        if self.num_voted_experts is None:
+            self.num_voted_experts = 8
+        else:
+            logger.info(
+                "Currently, MERRouter voted experts are only implemented for PanguProMoE. "
+                "For other models, setting this value will not take any effect.")
+
+        if not isinstance(self.num_voted_experts, int) or \
+                self.num_voted_experts <= 0 or self.num_voted_experts > 8:
+            raise ValueError("num_voted_experts should be an integer within the range of (0, 8].")
+
+
 class AscendSchedulerConfig:
     """
     Configuration Object for ascend_scheduler_config from additional_config
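
For orientation (not part of the commit): AscendModelConfig reads num_voted_experts from the ascend_model_config section of additional_config, so enabling the voting router would look roughly like the sketch below. The model path is a placeholder, and passing additional_config through the vLLM LLM entry point is an assumption based on how the other additional_config sections here are consumed.

# Hedged usage sketch; "path/to/pangu-pro-moe" is a placeholder and the
# additional_config plumbing is assumed, not shown in this commit.
from vllm import LLM

llm = LLM(
    model="path/to/pangu-pro-moe",
    additional_config={
        "ascend_model_config": {
            # must be an integer in (0, 8]; omitted -> defaults to 8,
            # which keeps the original top-k routing
            "num_voted_experts": 4,
        },
    },
)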

vllm_ascend/models/pangu_moe.py

Lines changed: 40 additions & 26 deletions

@@ -49,6 +49,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 
 logger = init_logger(__name__)

@@ -102,38 +103,48 @@ def pangu_group8_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    renormalize: bool = False,
     num_expert_group: int = 0,
     topk_group: int = 0,
     global_num_experts: int = 0,
 ):
-    ep_size = get_ep_group().world_size
-    local_num_experts = global_num_experts // ep_size
-    local_num_group = topk // ep_size
-    router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
-    scores = F.softmax(gating_output, dim=1)
-    scores = scores[...,
-                    get_ep_group().rank_in_group *
-                    local_num_experts:(get_ep_group().rank_in_group + 1) *
-                    local_num_experts]
-
-    router_weights = router_scale[get_ep_group().rank_in_group *
-                                  local_num_experts:
-                                  (get_ep_group().rank_in_group + 1) *
-                                  local_num_experts]
-    topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
-                                                   local_num_group, -1),
-                                       dim=-1)
-    bias = torch.arange(0,
-                        local_num_experts,
-                        topk,
-                        device=scores.device,
-                        dtype=torch.int32).unsqueeze(0)
-    topk_ids = topk_ids.to(torch.int32) + bias
+    local_num_experts = global_num_experts
+    local_num_group = topk
+    scores = F.softmax(gating_output, dim=1, dtype=torch.float16)
+    num_tokens = scores.shape[0]
+    router_weights = _ROUTER_SCALE.squeeze().to(torch.float16)
+
+    if self.num_voted_experts == 8:
+        # use original topk
+        topk_weights, topk_ids = torch.max(scores.view(scores.shape[0], local_num_group, -1), dim=-1)
+        bias = torch.arange(0, local_num_experts, topk, device=scores.device, dtype=torch.int32).unsqueeze(0)
+        topk_ids = topk_ids.to(torch.int32) + bias
 
+    else:
+        k = self.num_voted_experts
+        experts_per_group = local_num_experts // local_num_group
+        group_expert_indices = torch.arange(experts_per_group, dtype=torch.int32, device=scores.device).view(1, 1, -1)
+        group_expert_offset = (torch.arange(local_num_group, dtype=torch.int32, device=scores.device) * experts_per_group).unsqueeze(0)
+        expert_index_range = torch.arange(experts_per_group, dtype=torch.int32, device=scores.device)
+
+        scores_grouped = scores.view(num_tokens, local_num_group, experts_per_group)
+        best_expert_idx = torch.argmax(scores_grouped, dim=2)  # (num_tokens, num_groups)
+        vote_mask = (best_expert_idx.unsqueeze(-1).to(torch.int32) == group_expert_indices).to(torch.float16)
+
+        expert_vote_freq = vote_mask.sum(dim=0)
+
+        sorted_indices = torch.argsort(expert_vote_freq, dim=1, descending=True).to(torch.int32)
+        topk_experts = sorted_indices[:, :k]
+        keep_mask = ((topk_experts.unsqueeze(-1) == expert_index_range).any(dim=1)).unsqueeze(0)
+
+        masked_scores = torch.where(keep_mask, scores_grouped, 0)
+
+        topk_weights, best_pos_in_group = masked_scores.max(dim=2)
+        best_pos_in_group = best_pos_in_group.to(torch.int32)
+        topk_ids = (best_pos_in_group + group_expert_offset).to(torch.int32)
+
     flatten_topk_ids = topk_ids.view(-1)
-    router_weights = router_weights.index_select(0, flatten_topk_ids).view(
-        topk_ids.shape)
+    router_weights = router_weights.index_select(0, flatten_topk_ids).view(topk_ids.shape)
     topk_weights *= router_weights
 
     return topk_weights, topk_ids

@@ -192,6 +203,9 @@ def __init__(
             )
         else:
             self.shared_expert = None  # type: ignore
+
+        ascend_config = get_ascend_config()
+        self.num_voted_experts = ascend_config.ascend_model_config.num_voted_experts
 
     def forward(
         self,
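
Note that pangu_group8_topk is a module-level function, so the self.num_voted_experts reads in the second hunk only resolve if the value is made reachable at the routing call site; the commit stores num_voted_experts on the MoE module in __init__ (third hunk). Below is a minimal, self-contained sketch of the voting path under stated assumptions: plain PyTorch in float32 (the commit uses float16 on NPU), the vote count passed explicitly instead of read from self, and the renormalize/group arguments dropped. The function and argument names are hypothetical, not part of the commit.

# A hedged, self-contained sketch of the MERRouter voting added above.
import torch
import torch.nn.functional as F


def voted_group_topk_sketch(gating_output: torch.Tensor,
                            router_scale: torch.Tensor,
                            topk: int,
                            num_voted_experts: int):
    """Route each token to one expert per group, restricted to the
    num_voted_experts most-voted experts of that group.

    gating_output: (num_tokens, num_experts) router logits.
    router_scale:  (num_experts,) per-expert scale, like _ROUTER_SCALE.
    topk:          number of groups; one expert is chosen per group.
    """
    num_tokens, num_experts = gating_output.shape
    experts_per_group = num_experts // topk
    scores = F.softmax(gating_output, dim=1)  # the commit uses float16 here
    scores_grouped = scores.view(num_tokens, topk, experts_per_group)

    # Step 1: every token votes for its best expert within each group.
    best_expert_idx = torch.argmax(scores_grouped, dim=2)  # (num_tokens, topk)
    group_positions = torch.arange(experts_per_group, device=scores.device)
    vote_mask = (best_expert_idx.unsqueeze(-1) == group_positions).to(scores.dtype)

    # Step 2: per group, keep only the num_voted_experts most-voted experts.
    expert_vote_freq = vote_mask.sum(dim=0)  # (topk, experts_per_group)
    winners = torch.argsort(expert_vote_freq, dim=1,
                            descending=True)[:, :num_voted_experts]
    keep_mask = (winners.unsqueeze(-1) == group_positions).any(dim=1).unsqueeze(0)

    # Step 3: re-pick the best surviving expert per group for every token.
    masked_scores = torch.where(keep_mask, scores_grouped,
                                torch.zeros_like(scores_grouped))
    topk_weights, pos_in_group = masked_scores.max(dim=2)
    group_offset = (torch.arange(topk, device=scores.device) *
                    experts_per_group).unsqueeze(0)
    topk_ids = (pos_in_group + group_offset).to(torch.int32)

    # Step 4: apply the per-expert router scale, as the hunk above does.
    topk_weights = topk_weights * router_scale.squeeze()[topk_ids.long()]
    return topk_weights, topk_ids


# Toy example: 4 tokens, 64 experts in 8 groups, vote among the top 4 per group.
gating = torch.randn(4, 64)
scale = torch.ones(64)
weights, ids = voted_group_topk_sketch(gating, scale, topk=8, num_voted_experts=4)
print(weights.shape, ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])

With num_voted_experts=8 (the default installed by AscendModelConfig) every expert survives the keep mask, which is why the hunk above short-circuits that case to the original per-group max.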
