diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 86974770c5..0374d61e07 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -172,8 +172,6 @@ def fused_experts_with_mc2(
             npu_wait_tensor(shared_gate_up, expand_x)
             shared_act = shared_experts.act_fn(shared_gate_up)
 
-    w1 = w1.transpose(1, 2)
-
     group_list = expert_token_nums.to(torch.int64)
     gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[expand_x],
@@ -189,7 +187,6 @@ def fused_experts_with_mc2(
     gate_up_out = torch.cat(gate_up_out_list, dim=0)
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
 
-    w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
         weight=[w2],
@@ -266,7 +263,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     assert len(hidden_states_wrapper) == 1
     hidden_states = hidden_states_wrapper.pop()
 
-    w1 = w1.transpose(1, 2)
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w1],
@@ -369,7 +365,6 @@ def fused_experts_with_all2all(
         expanded_expert_idx, num_experts)
     expert_tokens = expert_tokens.to(torch.int64)
 
-    w1 = w1.transpose(1, 2)
     gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w1],
@@ -611,7 +606,6 @@ def fused_experts_moge(
         0, sorted_topk_ids).unsqueeze(-1)
     group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
 
-    w1 = w1.transpose(1, 2)
     gate_up_out = torch_npu.npu_grouped_matmul(
         x=[sorted_hidden_states],
         weight=[w1],
@@ -628,7 +622,6 @@ def fused_experts_moge(
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
     gate_up_out *= topk_scales
 
-    w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
         weight=[w2],
@@ -760,7 +753,6 @@ def fused_experts(
         expanded_expert_idx, num_experts)
     expert_tokens = expert_tokens.to(torch.int64)
 
-    w1 = w1.transpose(1, 2)
     gate_up_out_list = torch_npu.npu_grouped_matmul(
         x=[sorted_hidden_states],
         weight=[w1],
@@ -774,7 +766,6 @@ def fused_experts(
     gate_up_out = torch.cat(gate_up_out_list, dim=0)
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
 
-    w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
         weight=[w2],
@@ -1003,12 +994,13 @@ def __init__(self, moe: FusedMoEConfig = None):
     def process_weights_after_loading(self, layer):
         super(UnquantizedFusedMoEMethod,
               self).process_weights_after_loading(layer)
-        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w13_weight.data),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w2_weight.data),
-                                             requires_grad=False)
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
     def apply(
         self,
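
Every hunk above makes the same change: the per-call `w1 = w1.transpose(1, 2)` / `w2 = w2.transpose(1, 2)` that ran before each `torch_npu.npu_grouped_matmul` is removed, and the transpose is instead applied once (plus `.contiguous()`) in `process_weights_after_loading`, so the expert weights are already in the layout the grouped matmul expects when the forward path runs. Below is a minimal sketch of why the two orderings are numerically equivalent; it uses `torch.bmm` as a stand-in for the NPU grouped matmul, and the shapes and variable names are illustrative, not taken from the patch.

```python
import torch

num_experts, hidden_size, intermediate_size = 4, 16, 32
w1 = torch.randn(num_experts, intermediate_size, hidden_size)  # layout as loaded
x = torch.randn(num_experts, 8, hidden_size)                   # 8 tokens per expert

# Old path: transpose the weight inside every forward call.
out_old = torch.bmm(x, w1.transpose(1, 2))

# New path: transpose once after weight loading (and make it contiguous),
# then reuse the pre-transposed tensor on every call.
w1_pretransposed = w1.transpose(1, 2).contiguous()
out_new = torch.bmm(x, w1_pretransposed)

assert torch.allclose(out_old, out_new)
```

Doing the transpose at load time trades a one-off copy for removing a per-step view/reorder from the hot path, which is the intent of this change.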