|
1 | | -# import pytest |
2 | | -# import torch |
3 | | -# from vllm_ascend.ops.fused_moe import fused_moe |
4 | | - |
5 | | - |
6 | | -# def test_fused_moe(): |
7 | | -# # Since we are using native PyTorch operations in the function, the most reliable ground truth |
8 | | -# # for comparison is the manually computed output. By using hardcoded data, we can ensure |
9 | | -# # that the function produces the expected results and validate its correctness against a known reference. |
10 | | - |
11 | | -# # Step 1: Constructing inputs |
12 | | -# hidden_states = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]) |
13 | | - |
14 | | -# # w1: [3, 4, 3] (num_experts=3, intermediate_size*2=4, hidden_size=3) |
15 | | -# w1 = torch.tensor( |
16 | | -# [ |
17 | | -# [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [1.0, 1.0, -1.0], [1.0, -1.0, 1.0]], |
18 | | -# [[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]], |
19 | | -# [[-2.0, -1.0, 1.0], [2.0, -1.0, 1.0], [-1.0, 2.0, 1.0], [1.0, 1.0, -1.0]], |
20 | | -# ] |
21 | | -# ) |
22 | | - |
23 | | -# # w2: [3, 3, 2] (num_experts=3, hidden_size=3, intermediate_size=2) |
24 | | -# w2 = torch.tensor( |
25 | | -# [ |
26 | | -# [[1.0, 0.5], [2.0, -1.0], [0.0, 1.0]], |
27 | | -# [[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]], |
28 | | -# [[-2.0, 1.0], [1.0, -1.0], [2.0, 1.0]], |
29 | | -# ] |
30 | | -# ) |
31 | | - |
32 | | -# # gating_output: [2, 3] (num_tokens=2, num_experts=3) |
33 | | -# gating_output = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.0]]) |
34 | | - |
35 | | -# topk = 2 |
36 | | - |
37 | | -# global_num_experts = 3 |
38 | | - |
39 | | -# # Only has the first two experts |
40 | | -# expert_map = torch.tensor([0, 1, -1]) |
41 | | - |
42 | | -# renormalize = False |
43 | | - |
44 | | -# use_grouped_topk = False |
45 | | - |
46 | | -# # Step 2: Expected output calculation |
47 | | - |
48 | | -# # We use topk=2, which means we select the top 2 experts based on gating_output. |
49 | | -# # For sample 1, gating_output = [0.1, 0.7, 0.2], topk_weights = [0.7, 0.2], selected experts = 1, 2 |
50 | | -# # For sample 2, gating_output = [0.5, 0.4, 0.1], topk_weights = [0.5, 0.4], selected experts = 0, 1 |
51 | | - |
52 | | -# # 1. Calculate linear transformation of hidden_states with w1[0] -> F.linear(hidden_states, w1[0]) |
53 | | -# # 2. Apply gating function to get gate values -> F.silu(x[:, :intermediate_size]) |
54 | | -# # 3. Apply second linear transformation with w2[0] -> F.linear(x, w2[0]) |
55 | | -# # 4. Use the topk_weights for each sample and add the weighted outputs of experts 1 and 2 |
56 | | - |
57 | | -# expected_hidden_states = torch.tensor([[4.6763, -7.3797, 6.0280], [7.1232, 0.6220, 6.1364]]) |
58 | | - |
59 | | -# # Step 3: Running the fused_moe function |
60 | | -# final_output = fused_moe( |
61 | | -# hidden_states, w1, w2, gating_output, topk, global_num_experts, expert_map, renormalize, use_grouped_topk |
62 | | -# ) |
63 | | - |
64 | | -# # Step 4: Check the shape and values (this should match the expected result you computed manually) |
65 | | -# assert ( |
66 | | -# final_output.shape == hidden_states.shape |
67 | | -# ), f"Expected shape {hidden_states.shape}, but got {final_output.shape}" |
68 | | - |
69 | | -# assert torch.allclose(final_output, expected_hidden_states, atol=1e-4), "Output does not match expected result" |
70 | | - |
71 | | -# |
72 | | - |
73 | | - |
74 | 1 | # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
75 | 2 | # This file is a part of the vllm-ascend project. |
76 | 3 | # Adapted from vllm/tests/kernels/test_moe.py |
|
87 | 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
88 | 15 | # See the License for the specific language governing permissions and |
89 | 16 | # limitations under the License. |
90 | | -# |
91 | 17 | # SPDX-License-Identifier: Apache-2.0 |
92 | | - |
93 | 18 | """Tests for the MOE layers. |
94 | 19 |
|
95 | 20 | Run `pytest tests/ops/test_moe.py`. |
96 | 21 | """ |
| 22 | +from types import SimpleNamespace |
| 23 | + |
97 | 24 | import pytest |
98 | 25 | import torch |
99 | | - |
100 | 26 | from vllm.model_executor.layers.activation import SiluAndMul |
101 | | -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import fused_moe as iterative_moe |
102 | | -from vllm_ascend.ops.fused_moe import fused_experts, fused_experts_with_ep |
| 27 | + |
| 28 | +from vllm_ascend.ops.fused_moe import forward_oot |
103 | 29 |
|
104 | 30 | NUM_EXPERTS = [8, 64] |
105 | | -EP_SIZE = [1, 4] |
| 31 | +EP_SIZE = [1] |
106 | 32 | TOP_KS = [2, 6] |
107 | | -DEVICE = ["npu"] |
| 33 | +DEVICE = ["npu:0"] |
108 | 34 |
|
109 | 35 |
|
110 | | -def torch_moe(a, w1, w2, score, topk, expert_map): |
| 36 | +def torch_moe(a, w1, w2, score, topk, renormalize, num_expert_group, |
| 37 | + topk_group, scoring_func, e_score_correction_bias): |
111 | 38 | B, D = a.shape |
112 | 39 | a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) |
113 | 40 | out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) |
114 | | - score = torch.softmax(score, dim=-1, dtype=torch.float32) |
115 | | - topk_weight, topk_ids = torch.topk(score, topk) |
| 41 | + |
| 42 | + if scoring_func == "softmax": |
| 43 | + score = torch.softmax(score, dim=-1) |
| 44 | + elif scoring_func == "sigmoid": |
| 45 | + score = score.sigmoid() |
| 46 | + |
| 47 | + if e_score_correction_bias is not None: |
| 48 | + original_scores = score |
| 49 | + score = score + e_score_correction_bias.unsqueeze(0) |
| 50 | + |
| 51 | + # group_topk |
| 52 | + num_token = score.shape[0] |
| 53 | + group_score = score.view(num_token, num_expert_group, |
| 54 | + -1).max(dim=-1).values |
| 55 | + group_idx = torch.topk(group_score, k=topk_group, dim=-1, |
| 56 | + sorted=False)[1] # [n, top_k_group] |
| 57 | + group_mask = torch.zeros_like(group_score) # [n, n_group] |
| 58 | + group_mask.scatter_(1, group_idx, 1) # [n, n_group] |
| 59 | + score_mask = group_mask.unsqueeze(-1).expand( |
| 60 | + num_token, num_expert_group, |
| 61 | + score.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] |
| 62 | + score = score.masked_fill(~score_mask.bool(), 0.0) # [n, e] |
| 63 | + |
| 64 | + if e_score_correction_bias is not None: |
| 65 | + topk_ids = torch.topk(score, k=topk, dim=-1, sorted=False)[1] |
| 66 | + # Use original unbiased scores for the routing weights |
| 67 | + topk_weight = original_scores.gather(1, topk_ids) |
| 68 | + else: |
| 69 | + topk_weight, topk_ids = torch.topk(score, k=topk, dim=-1, sorted=False) |
| 70 | + |
| 71 | + if renormalize: |
| 72 | + topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) |
| 73 | + |
116 | 74 | topk_weight = topk_weight.view(-1) |
117 | 75 | topk_ids = topk_ids.view(-1) |
118 | | - if expert_map is not None: |
119 | | - topk_ids = expert_map[topk_ids] |
| 76 | + |
120 | 77 | for i in range(w1.shape[0]): |
121 | 78 | mask = topk_ids == i |
122 | 79 | if mask.sum(): |
123 | | - out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) |
124 | | - return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) |
| 80 | + out[mask] = SiluAndMul()( |
| 81 | + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) |
| 82 | + return (out.view(B, -1, w2.shape[1]) * |
| 83 | + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) |
125 | 84 |
|
126 | 85 |
|
127 | | -@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) |
128 | | -@pytest.mark.parametrize("n", [128, 1024, 2048]) |
129 | | -@pytest.mark.parametrize("k", [128, 511, 1024]) |
| 86 | +@pytest.mark.parametrize("m", [1]) |
| 87 | +@pytest.mark.parametrize("n", [128]) |
| 88 | +@pytest.mark.parametrize("k", [128]) |
130 | 89 | @pytest.mark.parametrize("e", NUM_EXPERTS) |
131 | 90 | @pytest.mark.parametrize("topk", TOP_KS) |
132 | 91 | @pytest.mark.parametrize("ep_size", EP_SIZE) |
133 | | -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 92 | +@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16]) |
134 | 93 | @pytest.mark.parametrize("device", DEVICE) |
135 | | -def test_fused_moe(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, device: str): |
| 94 | +def test_fused_moe(m: int, n: int, k: int, e: int, topk: int, ep_size: int, |
| 95 | + dtype: torch.dtype, device: str): |
136 | 96 | a = torch.randn((m, k), device=device, dtype=dtype) / 10 |
137 | 97 | w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 |
138 | 98 | w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 |
139 | 99 |
|
140 | 100 | score = torch.randn((m, e), device=device, dtype=dtype) |
| 101 | + topk_weights, topk_ids = score.topk(topk, dim=-1) |
| 102 | + topk_weights = topk_weights.to(dtype) |
| 103 | + topk_ids = topk_ids.to(torch.int32) |
| 104 | + |
| 105 | + layer = SimpleNamespace(w13_weight=w1, w2_weight=w2) |
141 | 106 |
|
142 | 107 | if ep_size > 1: |
143 | 108 | local_e = e // ep_size |
144 | | - e_ids = torch.randint(0, e, (local_e,), device=device, dtype=torch.int32) |
145 | | - e_map = torch.full((e,), -1, device=device, dtype=torch.int32) |
| 109 | + e_ids = torch.randint(0, |
| 110 | + e, (local_e, ), |
| 111 | + device=device, |
| 112 | + dtype=torch.int32) |
| 113 | + e_map = torch.full((e, ), -1, device=device, dtype=torch.int32) |
146 | 114 | e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32) |
147 | 115 | w1 = w1[e_ids] |
148 | 116 | w2 = w2[e_ids] |
149 | | - |
150 | | - output = fused_experts_with_ep( |
151 | | - a, w1, w2, score, topk, global_num_experts=e, expert_map=e_map, renormalize=False |
152 | | - ) |
153 | 117 | else: |
154 | | - fused_moe = fused_experts |
| 118 | + e_map = None |
| 119 | + |
| 120 | + output = forward_oot(None, layer, a, True, topk, topk_weights, True, 1, 1, |
| 121 | + -1, e_map) |
| 122 | + |
| 123 | + torch_output = torch_moe(a, |
| 124 | + w1, |
| 125 | + w2, |
| 126 | + score, |
| 127 | + topk, |
| 128 | + renormalize=True, |
| 129 | + num_expert_group=1, |
| 130 | + topk_group=1, |
| 131 | + scoring_func='sigmoid', |
| 132 | + e_score_correction_bias=None) |
155 | 133 |
|
156 | | - output = fused_moe(a, w1, w2, score, topk, global_num_experts=e, expert_map=e_map, renormalize=False) |
157 | | - torch_output = torch_moe(a, w1, w2, score, topk, e_map) |
158 | 134 | torch.testing.assert_close(output, torch_output, atol=2e-2, rtol=0) |
0 commit comments