diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 6c89f6fc1d..46192f6168 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -112,7 +112,7 @@ def mock_moe_env(mocker: MockerFixture):
             torch.randn(16, 2)
         )), \
         patch("torch_npu.npu_grouped_matmul", return_value=(
-            (torch.randn(8, 2), torch.randn(8, 2))
+            [torch.randn(16, 2)]
         )), \
         patch("torch_npu.npu_swiglu", return_value=(
             torch.randn(16, 2)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 04d288b063..550d97c0ff 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -204,11 +204,9 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -218,9 +216,7 @@ def fused_experts_with_mc2(
         group_list_type=1,
         group_type=0,
         group_list=group_list,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     # moeCombine
     kwargs_mc2 = {
@@ -311,9 +307,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     hidden_states = torch_npu.npu_swiglu(hidden_states)
 
     w2 = w2.transpose(1, 2)
@@ -324,9 +319,8 @@ def apply_mlp(
         group_list_type=group_list_type,
         group_type=0,
         group_list=group_list,
-    )
+    )[0]
 
-    hidden_states = torch.cat(hidden_states, dim=0)
     return hidden_states
 
 
@@ -416,23 +410,19 @@ def fused_experts_with_all2all(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    hidden_states = torch.cat(gate_up_out_list, dim=0)
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(
+    hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
         weight=[w2],
         split_item=2,
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    hidden_states = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         resorted_idx = torch.argsort(sorted_idx)
@@ -822,11 +812,9 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
+    )[0]
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
 
     w2 = w2.transpose(1, 2)
     down_out_list = torch_npu.npu_grouped_matmul(
@@ -836,9 +824,7 @@ def fused_experts(
         group_list_type=0,
         group_type=0,
         group_list=expert_tokens,
-    )
-
-    down_out_list = torch.cat(down_out_list, dim=0)
+    )[0]
 
     if expert_map is not None:
         weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
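
Note: with split_item=2, torch_npu.npu_grouped_matmul already concatenates the per-group results along dim 0 and returns them as a single-element list, so the follow-up torch.cat calls were redundant; indexing [0] extracts the concatenated tensor directly, and the unit-test mock is updated to return a one-element list to match. Below is a minimal pure-PyTorch sketch of that return contract; grouped_matmul_stub is a hypothetical stand-in (torch_npu only runs on Ascend NPUs), not the real kernel.

import torch

def grouped_matmul_stub(x, weight, group_list, **kwargs):
    # Hypothetical stand-in emulating npu_grouped_matmul with split_item=2:
    # compute each group's matmul, concatenate along dim 0, and return the
    # result wrapped in a single-element list.
    tokens, offset, outs = x[0], 0, []
    for i, end in enumerate(group_list):
        outs.append(tokens[offset:end] @ weight[0][i])
        offset = end
    return [torch.cat(outs, dim=0)]

tokens = torch.randn(16, 4)   # 16 routed tokens, hidden size 4
w1 = torch.randn(2, 4, 8)     # one (4, 8) weight per expert
group_list = [10, 16]         # cumulative token counts: 10 for expert 0, 6 for expert 1

# Old pattern: out = torch.cat(grouped_matmul_stub(...), dim=0)
# New pattern (this PR): the list holds exactly one tensor, so take [0].
out = grouped_matmul_stub(x=[tokens], weight=[w1], group_list=group_list)[0]
assert out.shape == (16, 8)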