import pytest
import torch
from vllm_ascend.ops.fused_moe import fused_moe


def test_fused_moe():
    # Since the function under test is built from native PyTorch operations, the most
    # reliable ground truth is a manually computed output. Hardcoded inputs keep the
    # expected result reproducible and let us validate correctness against a known reference.

    # Step 1: Constructing inputs
    hidden_states = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])

    # w1: [3, 4, 3] (num_experts=3, intermediate_size*2=4, hidden_size=3)
    w1 = torch.tensor(
        [
            [[1.0, 0.0, -1.0], [2.0, 1.0, 0.0], [1.0, 1.0, -1.0], [1.0, -1.0, 1.0]],
            [[-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [2.0, -2.0, 2.0], [1.0, 0.0, -1.0]],
            [[-2.0, -1.0, 1.0], [2.0, -1.0, 1.0], [-1.0, 2.0, 1.0], [1.0, 1.0, -1.0]],
        ]
    )

    # w2: [3, 3, 2] (num_experts=3, hidden_size=3, intermediate_size=2)
    w2 = torch.tensor(
        [
            [[1.0, 0.5], [2.0, -1.0], [0.0, 1.0]],
            [[1.0, 1.0], [-1.0, 1.0], [1.0, -0.0]],
            [[-2.0, 1.0], [1.0, -1.0], [2.0, 1.0]],
        ]
    )

    # gating_output: [2, 3] (num_tokens=2, num_experts=3)
    gating_output = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.0]])

    topk = 2

    global_num_experts = 3

    # Only the first two experts are local to this rank; expert 2 maps to -1 (not present here)
    expert_map = torch.tensor([0, 1, -1])

    # Do not renormalize the top-k routing weights after selection
    renormalize = False

    # Use plain top-k routing rather than grouped top-k
    use_grouped_topk = False

    # Step 2: Expected output calculation

    # With topk=2, the top 2 experts per token are selected from the softmax of gating_output.
    # For token 0, gating_output = [0.0, 0.5, 0.5] -> experts 1 and 2 are selected, each with a
    #   routing weight of ~0.3837 (renormalize=False, so the softmax values are used directly).
    #   Expert 2 maps to -1 in expert_map, so only expert 1 contributes on this rank.
    # For token 1, gating_output = [0.5, 0.5, 0.0] -> experts 0 and 1 are selected, each with a
    #   routing weight of ~0.3837; both are local, so both contribute.

    # For each selected local expert e:
    # 1. First linear projection with w1[e] -> F.linear(hidden_states, w1[e])
    # 2. SwiGLU activation -> F.silu(x[:, :intermediate_size]) * x[:, intermediate_size:]
    # 3. Second linear projection with w2[e] -> F.linear(x, w2[e])
    # 4. Scale each expert's output by its routing weight and sum over the selected experts.
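
    # A minimal reference sketch of that computation in native PyTorch. This is only an
    # illustration of how the hardcoded expected values can be reproduced, not a claim
    # about the kernel's actual internals (which may differ on NPU):
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    reference = torch.zeros_like(hidden_states)
    for token in range(hidden_states.shape[0]):
        for weight, expert in zip(topk_weights[token], topk_ids[token]):
            if expert_map[expert] == -1:
                continue  # expert is not local to this rank, contributes nothing here
            x = hidden_states[token] @ w1[expert].T          # [intermediate_size * 2]
            gate, up = x.chunk(2, dim=-1)
            x = torch.nn.functional.silu(gate) * up          # SwiGLU
            reference[token] += weight * (x @ w2[expert].T)  # [hidden_size]
    # `reference` matches expected_hidden_states below to ~4 decimal places.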

    expected_hidden_states = torch.tensor([[4.6763, -7.3797, 6.0280], [7.1232, 0.6220, 6.1364]])

    # Step 3: Running the fused_moe function
    final_output = fused_moe(
        hidden_states, w1, w2, gating_output, topk, global_num_experts, expert_map, renormalize, use_grouped_topk
    )

    # Step 4: Check the shape and values against the manually computed expected result
    assert (
        final_output.shape == hidden_states.shape
    ), f"Expected shape {hidden_states.shape}, but got {final_output.shape}"

    assert torch.allclose(final_output, expected_hidden_states, atol=1e-4), "Output does not match expected result"
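
To run the test locally, something along these lines should work (the file path is an
assumption for illustration; use wherever the test lands in the repo):

    pytest -q tests/ops/test_fused_moe.py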