Commit b8d514d

Author: Yizhou Liu

[Test] Fix test case and format

Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>

1 parent: 9bd073c

File tree

2 files changed: +130 -134 lines changed


tests/ops/test_fused_moe.py

Lines changed: 77 additions & 101 deletions
@@ -1,76 +1,3 @@
-# import pytest
-# import torch
-# from vllm_ascend.ops.fused_moe import fused_moe
-
-
-# def test_fused_moe():
-#     # Since we are using native PyTorch operations in the function, the most reliable ground truth
-#     # for comparison is the manually computed output. By using hardcoded data, we can ensure
-#     # that the function produces the expected results and validate its correctness against a known reference.
-
-#     # Step 1: Constructing inputs
-#     hidden_states = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
-
-#     # w1: [3, 4, 3] (num_experts=3, intermediate_size*2=4, hidden_size=3)
-#     w1 = torch.tensor(
-#         [
-#             [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [1.0, 1.0, -1.0], [1.0, -1.0, 1.0]],
-#             [[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]],
-#             [[-2.0, -1.0, 1.0], [2.0, -1.0, 1.0], [-1.0, 2.0, 1.0], [1.0, 1.0, -1.0]],
-#         ]
-#     )
-
-#     # w2: [3, 3, 2] (num_experts=3, hidden_size=3, intermediate_size=2)
-#     w2 = torch.tensor(
-#         [
-#             [[1.0, 0.5], [2.0, -1.0], [0.0, 1.0]],
-#             [[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]],
-#             [[-2.0, 1.0], [1.0, -1.0], [2.0, 1.0]],
-#         ]
-#     )
-
-#     # gating_output: [2, 3] (num_tokens=2, num_experts=3)
-#     gating_output = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.0]])
-
-#     topk = 2
-
-#     global_num_experts = 3
-
-#     # Only has the first two experts
-#     expert_map = torch.tensor([0, 1, -1])
-
-#     renormalize = False
-
-#     use_grouped_topk = False
-
-#     # Step 2: Expected output calculation
-
-#     # We use topk=2, which means we select the top 2 experts based on gating_output.
-#     # For sample 1, gating_output = [0.1, 0.7, 0.2], topk_weights = [0.7, 0.2], selected experts = 1, 2
-#     # For sample 2, gating_output = [0.5, 0.4, 0.1], topk_weights = [0.5, 0.4], selected experts = 0, 1
-
-#     # 1. Calculate linear transformation of hidden_states with w1[0] -> F.linear(hidden_states, w1[0])
-#     # 2. Apply gating function to get gate values -> F.silu(x[:, :intermediate_size])
-#     # 3. Apply second linear transformation with w2[0] -> F.linear(x, w2[0])
-#     # 4. Use the topk_weights for each sample and add the weighted outputs of experts 1 and 2
-
-#     expected_hidden_states = torch.tensor([[4.6763, -7.3797, 6.0280], [7.1232, 0.6220, 6.1364]])
-
-#     # Step 3: Running the fused_moe function
-#     final_output = fused_moe(
-#         hidden_states, w1, w2, gating_output, topk, global_num_experts, expert_map, renormalize, use_grouped_topk
-#     )
-
-#     # Step 4: Check the shape and values (this should match the expected result you computed manually)
-#     assert (
-#         final_output.shape == hidden_states.shape
-#     ), f"Expected shape {hidden_states.shape}, but got {final_output.shape}"
-
-#     assert torch.allclose(final_output, expected_hidden_states, atol=1e-4), "Output does not match expected result"
-
-#
-
-
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
@@ -87,72 +14,121 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
 # SPDX-License-Identifier: Apache-2.0
-
 """Tests for the MOE layers.

 Run `pytest tests/ops/test_moe.py`.
 """
+from types import SimpleNamespace
+
 import pytest
 import torch
-
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe.moe_torch_iterative import fused_moe as iterative_moe
-from vllm_ascend.ops.fused_moe import fused_experts, fused_experts_with_ep
+
+from vllm_ascend.ops.fused_moe import forward_oot

 NUM_EXPERTS = [8, 64]
-EP_SIZE = [1, 4]
+EP_SIZE = [1]
 TOP_KS = [2, 6]
-DEVICE = ["npu"]
+DEVICE = ["npu:0"]


-def torch_moe(a, w1, w2, score, topk, expert_map):
+def torch_moe(a, w1, w2, score, topk, renormalize, num_expert_group,
+              topk_group, scoring_func, e_score_correction_bias):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
-    score = torch.softmax(score, dim=-1, dtype=torch.float32)
-    topk_weight, topk_ids = torch.topk(score, topk)
+
+    if scoring_func == "softmax":
+        score = torch.softmax(score, dim=-1)
+    elif scoring_func == "sigmoid":
+        score = score.sigmoid()
+
+    if e_score_correction_bias is not None:
+        original_scores = score
+        score = score + e_score_correction_bias.unsqueeze(0)
+
+    # group_topk
+    num_token = score.shape[0]
+    group_score = score.view(num_token, num_expert_group,
+                             -1).max(dim=-1).values
+    group_idx = torch.topk(group_score, k=topk_group, dim=-1,
+                           sorted=False)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_score)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = group_mask.unsqueeze(-1).expand(
+        num_token, num_expert_group,
+        score.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+    score = score.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(score, k=topk, dim=-1, sorted=False)[1]
+        # Use original unbiased scores for the routing weights
+        topk_weight = original_scores.gather(1, topk_ids)
+    else:
+        topk_weight, topk_ids = torch.topk(score, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
+
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
-    if expert_map is not None:
-        topk_ids = expert_map[topk_ids]
+
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
-    return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+    return (out.view(B, -1, w2.shape[1]) *
+            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


-@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("m", [1])
+@pytest.mark.parametrize("n", [128])
+@pytest.mark.parametrize("k", [128])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])  #, torch.bfloat16])
 @pytest.mark.parametrize("device", DEVICE)
-def test_fused_moe(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, device: str):
+def test_fused_moe(m: int, n: int, k: int, e: int, topk: int, ep_size: int,
+                   dtype: torch.dtype, device: str):
     a = torch.randn((m, k), device=device, dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

     score = torch.randn((m, e), device=device, dtype=dtype)
+    topk_weights, topk_ids = score.topk(topk, dim=-1)
+    topk_weights = topk_weights.to(dtype)
+    topk_ids = topk_ids.to(torch.int32)
+
+    layer = SimpleNamespace(w13_weight=w1, w2_weight=w2)

     if ep_size > 1:
         local_e = e // ep_size
-        e_ids = torch.randint(0, e, (local_e,), device=device, dtype=torch.int32)
-        e_map = torch.full((e,), -1, device=device, dtype=torch.int32)
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device=device,
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
         e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
         w1 = w1[e_ids]
         w2 = w2[e_ids]
-
-        output = fused_experts_with_ep(
-            a, w1, w2, score, topk, global_num_experts=e, expert_map=e_map, renormalize=False
-        )
     else:
-        fused_moe = fused_experts
+        e_map = None
+
+    output = forward_oot(None, layer, a, True, topk, topk_weights, True, 1, 1,
+                         -1, e_map)
+
+    torch_output = torch_moe(a,
+                             w1,
+                             w2,
+                             score,
+                             topk,
+                             renormalize=True,
+                             num_expert_group=1,
+                             topk_group=1,
+                             scoring_func='sigmoid',
+                             e_score_correction_bias=None)

-    output = fused_moe(a, w1, w2, score, topk, global_num_experts=e, expert_map=e_map, renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
     torch.testing.assert_close(output, torch_output, atol=2e-2, rtol=0)
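
Note on the reference path added above: the new torch_moe routes tokens with a grouped top-k. Experts are split into num_expert_group buckets, only the topk_group highest-scoring buckets stay eligible, and the final top-k is taken over the masked scores. Below is a minimal, CPU-only sketch of that routing step; the function name, shapes, and sample values are illustrative and are not taken from the test.

import torch


def grouped_topk(score, topk, num_expert_group, topk_group, renormalize):
    # score: [tokens, experts]; experts are split into num_expert_group buckets.
    num_token = score.shape[0]
    # Rank each bucket by its strongest expert, keep only the topk_group best buckets.
    group_score = score.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_score, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_score)
    group_mask.scatter_(1, group_idx, 1)
    # Expand the bucket mask back to per-expert granularity and zero out losers.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        score.shape[-1] // num_expert_group).reshape(num_token, -1)
    score = score.masked_fill(~score_mask.bool(), 0.0)
    # Final per-token top-k over the masked scores.
    topk_weight, topk_ids = torch.topk(score, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    return topk_weight, topk_ids


weights, ids = grouped_topk(torch.randn(2, 8).sigmoid(), topk=2,
                            num_expert_group=2, topk_group=1, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([2, 2]) torch.Size([2, 2])

With num_expert_group=1 and topk_group=1, the values the test passes, the group mask keeps every expert, so the routing collapses to a plain top-k over the sigmoid scores.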

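For each selected expert, the reference then applies the gated activation that vLLM's SiluAndMul layer implements: the first half of a[mask] @ w1[i].T goes through SiLU and is multiplied elementwise by the second half, then the result is projected back through w2[i]. A small self-contained equivalent using plain torch.nn.functional, with hypothetical names and sizes chosen only for illustration:

import torch
import torch.nn.functional as F


def expert_forward(x, w1, w2):
    # x: [tokens, hidden], w1: [2 * intermediate, hidden], w2: [hidden, intermediate]
    gate, up = (x @ w1.transpose(0, 1)).chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w2.transpose(0, 1)


x = torch.randn(4, 128) / 10
w1 = torch.randn(2 * 64, 128) / 10
w2 = torch.randn(128, 64) / 10
print(expert_forward(x, w1, w2).shape)  # torch.Size([4, 128])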