49 | 49 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
50 | 50 | from vllm.sequence import IntermediateTensors |
51 | 51 |
| 52 | +from vllm_ascend.ascend_config import get_ascend_config |
52 | 53 | from vllm_ascend.distributed.parallel_state import get_ep_group |
53 | 54 |
54 | 55 | logger = init_logger(__name__) |
@@ -102,38 +103,48 @@ def pangu_group8_topk( |
102 | 103 |     hidden_states: torch.Tensor, |
103 | 104 |     gating_output: torch.Tensor, |
104 | 105 |     topk: int, |
105 | | -    renormalize: bool, |
| 106 | +    renormalize: bool = False, |
106 | 107 |     num_expert_group: int = 0, |
107 | 108 |     topk_group: int = 0, |
108 | 109 |     global_num_experts: int = 0, |
109 | 110 | ): |
110 | | -    ep_size = get_ep_group().world_size |
111 | | -    local_num_experts = global_num_experts // ep_size |
112 | | -    local_num_group = topk // ep_size |
113 | | -    router_scale = _ROUTER_SCALE.squeeze()  # type: ignore |
114 | | -    scores = F.softmax(gating_output, dim=1) |
115 | | -    scores = scores[..., |
116 | | -                    get_ep_group().rank_in_group * |
117 | | -                    local_num_experts:(get_ep_group().rank_in_group + 1) * |
118 | | -                    local_num_experts] |
119 | | - |
120 | | -    router_weights = router_scale[get_ep_group().rank_in_group * |
121 | | -                                  local_num_experts: |
122 | | -                                  (get_ep_group().rank_in_group + 1) * |
123 | | -                                  local_num_experts] |
124 | | -    topk_weights, topk_ids = torch.max(scores.view(scores.shape[0], |
125 | | -                                                   local_num_group, -1), |
126 | | -                                       dim=-1) |
127 | | -    bias = torch.arange(0, |
128 | | -                        local_num_experts, |
129 | | -                        topk, |
130 | | -                        device=scores.device, |
131 | | -                        dtype=torch.int32).unsqueeze(0) |
132 | | -    topk_ids = topk_ids.to(torch.int32) + bias |
| 111 | +    local_num_experts = global_num_experts  # no expert-parallel slicing: every rank scores the full expert set |
| 112 | +    local_num_group = topk |
| 113 | +    scores = F.softmax(gating_output, dim=1, dtype=torch.float16) |
| 114 | +    num_tokens = scores.shape[0] |
| 115 | +    router_weights = _ROUTER_SCALE.squeeze().to(torch.float16) |
| 116 | +
| 117 | +    if _NUM_VOTED_EXPERTS == 8:  # module-level global set in __init__; this free function has no self |
| 118 | +        # use the original per-group argmax top-k |
| 119 | +        topk_weights, topk_ids = torch.max(scores.view(num_tokens, local_num_group, -1), dim=-1) |
| 120 | +        bias = torch.arange(0, local_num_experts, topk, device=scores.device, dtype=torch.int32).unsqueeze(0) |
| 121 | +        topk_ids = topk_ids.to(torch.int32) + bias |
133 | 122 |
| 123 | +    else: |
| 124 | +        k = _NUM_VOTED_EXPERTS |
| 125 | +        experts_per_group = local_num_experts // local_num_group |
| 126 | +        group_expert_indices = torch.arange(experts_per_group, dtype=torch.int32, device=scores.device).view(1, 1, -1) |
| 127 | +        group_expert_offset = (torch.arange(local_num_group, dtype=torch.int32, device=scores.device) * experts_per_group).unsqueeze(0) |
| 128 | +        expert_index_range = torch.arange(experts_per_group, dtype=torch.int32, device=scores.device) |
| 129 | + |
| 130 | +        scores_grouped = scores.view(num_tokens, local_num_group, experts_per_group) |
| 131 | +        best_expert_idx = torch.argmax(scores_grouped, dim=2)  # each token's best expert per group: (num_tokens, num_groups) |
| 132 | +        vote_mask = (best_expert_idx.unsqueeze(-1).to(torch.int32) == group_expert_indices).to(torch.float16) |
| 133 | + |
| 134 | +        expert_vote_freq = vote_mask.sum(dim=0)  # votes per expert, per group: (num_groups, experts_per_group) |
| 135 | + |
| 136 | +        sorted_indices = torch.argsort(expert_vote_freq, dim=1, descending=True).to(torch.int32) |
| 137 | +        topk_experts = sorted_indices[:, :k]  # the k most-voted experts in each group |
| 138 | +        keep_mask = ((topk_experts.unsqueeze(-1) == expert_index_range).any(dim=1)).unsqueeze(0) |
| 139 | + |
| 140 | +        masked_scores = torch.where(keep_mask, scores_grouped, 0)  # zero out experts that lost the vote |
| 141 | + |
| 142 | +        topk_weights, best_pos_in_group = masked_scores.max(dim=2) |
| 143 | +        best_pos_in_group = best_pos_in_group.to(torch.int32) |
| 144 | +        topk_ids = (best_pos_in_group + group_expert_offset).to(torch.int32) |
| 145 | + |
134 | 146 |     flatten_topk_ids = topk_ids.view(-1) |
135 | | -    router_weights = router_weights.index_select(0, flatten_topk_ids).view( |
136 | | -        topk_ids.shape) |
| 147 | +    router_weights = router_weights.index_select(0, flatten_topk_ids).view(topk_ids.shape) |
137 | 148 |     topk_weights *= router_weights |
138 | 149 |
139 | 150 |     return topk_weights, topk_ids |
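To make the voting branch easier to follow outside the diff, here is a minimal standalone sketch of the same three steps: per-group voting, shortlisting the most-voted experts, then a per-token argmax over the survivors. The name `voted_group_topk`, the 64-expert / 8-group shapes, and the toy input are illustrative assumptions rather than part of this PR, and the `_ROUTER_SCALE` multiplication applied afterwards in the real code is omitted:

```python
import torch
import torch.nn.functional as F

def voted_group_topk(gating_output: torch.Tensor,
                     topk: int = 8,
                     num_experts: int = 64,
                     num_voted_experts: int = 4):
    # Softmax over all experts, in fp16 as in the diff.
    scores = F.softmax(gating_output, dim=1, dtype=torch.float16)
    num_tokens = scores.shape[0]
    experts_per_group = num_experts // topk

    grouped = scores.view(num_tokens, topk, experts_per_group)
    # Step 1: every token "votes" for its best expert inside each group.
    best = grouped.argmax(dim=2)                                  # (num_tokens, topk)
    votes = F.one_hot(best, experts_per_group).to(torch.float16).sum(dim=0)
    # Step 2: shortlist the num_voted_experts most-voted experts per group.
    keep = votes.argsort(dim=1, descending=True)[:, :num_voted_experts]
    arange = torch.arange(experts_per_group, device=scores.device)
    keep_mask = (keep.unsqueeze(-1) == arange).any(dim=1)         # (topk, experts_per_group)
    # Step 3: per-token argmax restricted to the surviving experts.
    masked = torch.where(keep_mask, grouped, torch.zeros_like(grouped))
    topk_weights, pos = masked.max(dim=2)
    offsets = torch.arange(topk, device=scores.device) * experts_per_group
    return topk_weights, (pos + offsets).to(torch.int32)

w, ids = voted_group_topk(torch.randn(4, 64))
print(w.shape, ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```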
@@ -192,6 +203,9 @@ def __init__( |
192 | 203 |             ) |
193 | 204 |         else: |
194 | 205 |             self.shared_expert = None  # type: ignore |
| 206 | + |
| 207 | +        global _NUM_VOTED_EXPERTS  # mirror the _ROUTER_SCALE pattern: pangu_group8_topk is a free function |
| 208 | +        _NUM_VOTED_EXPERTS = get_ascend_config().ascend_model_config.num_voted_experts |
195 | 209 |
196 | 210 |     def forward( |
197 | 211 |         self, |
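On the configuration side, the new knob is read once at construction time via `get_ascend_config()`. Assuming it is supplied the way other vllm-ascend options are, through vLLM's `additional_config`, with a key layout mirroring the attribute path `ascend_model_config.num_voted_experts` (the exact layout is an assumption, not shown in this diff), enabling five-expert voting might look like:

```python
from vllm import LLM

# Hypothetical wiring: the nested keys mirror the attribute path
# get_ascend_config().ascend_model_config.num_voted_experts read in
# __init__; the key layout and model path are placeholders.
llm = LLM(
    model="path/to/pangu-pro-moe",
    additional_config={"ascend_model_config": {"num_voted_experts": 5}},
)
```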
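As a quick sanity check, reusing `voted_group_topk` from the sketch above (the seed and shapes are arbitrary): when `num_voted_experts` equals the per-group expert count, every expert survives the vote, so the voting branch should reproduce the `== 8` fast path exactly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gating = torch.randn(4, 64)

# Fast path, mirroring the `if` branch: per-group argmax plus a stride-8 bias.
scores = F.softmax(gating, dim=1, dtype=torch.float16)
fast_w, fast_pos = torch.max(scores.view(4, 8, -1), dim=-1)
bias = torch.arange(0, 64, 8, dtype=torch.int32).unsqueeze(0)
fast_ids = fast_pos.to(torch.int32) + bias

# Voting path with all eight experts per group kept: nothing is masked out.
vote_w, vote_ids = voted_group_topk(gating, num_voted_experts=8)

assert torch.equal(fast_w, vote_w)
assert torch.equal(fast_ids, vote_ids)
```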