|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# This file is a part of the vllm-ascend project. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from typing import Callable, Optional |
| 19 | + |
| 20 | +import torch |
| 21 | +import torch_npu |
| 22 | +from vllm.model_executor.layers.fused_moe.layer import \ |
| 23 | + UnquantizedFusedMoEMethod |
| 24 | + |
| 25 | + |
| 26 | +def group_topk(hidden_states: torch.Tensor, |
| 27 | + gating_output: torch.Tensor, |
| 28 | + topk: int, |
| 29 | + renormalize: bool, |
| 30 | + num_expert_group: Optional[int] = 0, |
| 31 | + topk_group: Optional[int] = 0, |
| 32 | + scoring_func: str = "softmax", |
| 33 | + e_score_correction_bias: Optional[torch.Tensor] = None): |
| 34 | + |
| 35 | + assert hidden_states.shape[0] == gating_output.shape[0], ( |
| 36 | + "Number of tokens mismatch") |
| 37 | + |
| 38 | + if scoring_func == "softmax": |
| 39 | + scores = torch.softmax(gating_output, dim=-1) |
| 40 | + elif scoring_func == "sigmoid": |
| 41 | + scores = gating_output.sigmoid() |
| 42 | + else: |
| 43 | + raise ValueError(f"Unsupported scoring function: {scoring_func}") |
| 44 | + |
| 45 | + if e_score_correction_bias is not None: |
| 46 | + # Store original scores before applying correction bias. We use biased |
| 47 | + # scores for expert selection but original scores for routing weights |
| 48 | + original_scores = scores |
| 49 | + scores = scores + e_score_correction_bias.unsqueeze(0) |
| 50 | + |
| 51 | + torch_npu.npu_group_topk(input=scores, |
| 52 | + out=scores, |
| 53 | + group_num=num_expert_group, |
| 54 | + k=topk_group) |
| 55 | + if e_score_correction_bias is not None: |
| 56 | + topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1] |
| 57 | + # Use original unbiased scores for the routing weights |
| 58 | + topk_weights = original_scores.gather(1, topk_ids) |
| 59 | + else: |
| 60 | + topk_weights, topk_ids = torch.topk(scores, |
| 61 | + k=topk, |
| 62 | + dim=-1, |
| 63 | + sorted=False) |
| 64 | + |
| 65 | + if renormalize: |
| 66 | + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) |
| 67 | + |
| 68 | + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) |
| 69 | + |
| 70 | + |
| 71 | +def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, |
| 72 | + w2: torch.Tensor, topk_weights: torch.Tensor, |
| 73 | + topk_ids: torch.Tensor, top_k: int): |
| 74 | + # Check constraints. |
| 75 | + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" |
| 76 | + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" |
| 77 | + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" |
| 78 | + assert w1.is_contiguous(), "Expert weights1 must be contiguous" |
| 79 | + assert w2.is_contiguous(), "Expert weights2 must be contiguous" |
| 80 | + assert hidden_states.dtype in [ |
| 81 | + torch.float32, torch.float16, torch.bfloat16 |
| 82 | + ] |
| 83 | + ori_shape = hidden_states.shape |
| 84 | + if len(ori_shape) == 3: |
| 85 | + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) |
| 86 | + |
| 87 | + num_tokens, _ = hidden_states.shape |
| 88 | + E, N, _ = w1.shape |
| 89 | + |
| 90 | + row_idx_len = num_tokens * top_k |
| 91 | + row_idx = torch.arange(0, |
| 92 | + row_idx_len, |
| 93 | + dtype=torch.int32, |
| 94 | + device=topk_weights.device).view(top_k, -1).permute( |
| 95 | + 1, 0).contiguous() |
| 96 | + expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( |
| 97 | + hidden_states, |
| 98 | + row_idx=row_idx, |
| 99 | + expert_idx=topk_ids, |
| 100 | + active_num=num_tokens) |
| 101 | + |
| 102 | + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( |
| 103 | + expanded_expert_idx, E) |
| 104 | + expert_tokens = expert_tokens.to(torch.int64) |
| 105 | + |
| 106 | + w1 = w1.transpose(1, 2) |
| 107 | + gate_up_out_list = torch_npu.npu_grouped_matmul(x=[expanded_x], |
| 108 | + weight=[w1], |
| 109 | + split_item=2, |
| 110 | + group_list_type=0, |
| 111 | + group_type=0, |
| 112 | + group_list=expert_tokens) |
| 113 | + |
| 114 | + # TODO: Remove this in the future. |
| 115 | + gate_up_out = torch.cat(gate_up_out_list, dim=0) |
| 116 | + gate_up_out = torch_npu.npu_swiglu(gate_up_out) |
| 117 | + |
| 118 | + w2 = w2.transpose(1, 2) |
| 119 | + down_out_list = torch_npu.npu_grouped_matmul(x=[gate_up_out], |
| 120 | + weight=[w2], |
| 121 | + split_item=2, |
| 122 | + group_list_type=0, |
| 123 | + group_type=0, |
| 124 | + group_list=expert_tokens) |
| 125 | + |
| 126 | + down_out_list = torch.cat(down_out_list, dim=0) |
| 127 | + # TODO: Reorder device memory 2 times here, replace the current |
| 128 | + # implementation here when suitable operators become available. |
| 129 | + routing_weights = topk_weights.to(down_out_list.dtype) |
| 130 | + hidden_states = torch_npu.npu_moe_finalize_routing( |
| 131 | + down_out_list, |
| 132 | + skip1=None, |
| 133 | + skip2=None, |
| 134 | + bias=None, |
| 135 | + scales=routing_weights, |
| 136 | + expanded_src_to_dst_row=expanded_row_idx, |
| 137 | + export_for_source_row=topk_ids) |
| 138 | + if len(ori_shape) == 3: |
| 139 | + hidden_states = hidden_states.view(ori_shape) |
| 140 | + return hidden_states |
| 141 | + |
| 142 | + |
| 143 | +def forward_oot( |
| 144 | + self, |
| 145 | + layer: torch.nn.Module, |
| 146 | + x: torch.Tensor, |
| 147 | + use_grouped_topk: bool, |
| 148 | + top_k: int, |
| 149 | + router_logits: torch.Tensor, |
| 150 | + renormalize: bool, |
| 151 | + topk_group: Optional[int] = None, |
| 152 | + num_expert_group: Optional[int] = None, |
| 153 | + custom_routing_function: Optional[Callable] = None, |
| 154 | + scoring_func: str = "softmax", |
| 155 | + e_score_correction_bias: Optional[torch.Tensor] = None |
| 156 | +) -> torch.Tensor: |
| 157 | + |
| 158 | + topk_weights, topk_ids = group_topk( |
| 159 | + hidden_states=x, |
| 160 | + gating_output=router_logits, |
| 161 | + topk=top_k, |
| 162 | + renormalize=renormalize, |
| 163 | + num_expert_group=num_expert_group, |
| 164 | + topk_group=topk_group, |
| 165 | + scoring_func=scoring_func, |
| 166 | + e_score_correction_bias=e_score_correction_bias) |
| 167 | + |
| 168 | + return fused_experts(hidden_states=x, |
| 169 | + w1=layer.w13_weight, |
| 170 | + w2=layer.w2_weight, |
| 171 | + topk_weights=topk_weights, |
| 172 | + topk_ids=topk_ids, |
| 173 | + top_k=top_k) |
| 174 | + |
| 175 | + |
| 176 | +UnquantizedFusedMoEMethod.forward_oot = forward_oot |
0 commit comments