@@ -22,7 +22,6 @@
 
 import pytest
 import torch
-from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 
 from vllm_ascend.ops.fused_moe import fused_experts
@@ -68,36 +67,31 @@ def test_fused_experts( |
     dtype: torch.dtype,
     device: str,
 ):
-    vllm_config = VllmConfig()
-    with set_current_vllm_config(vllm_config):
-        a = torch.randn((m, k), device=device, dtype=dtype) / 10
-        w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
-        w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
 
-        score = torch.randn((m, e), device=device, dtype=dtype)
+    score = torch.randn((m, e), device=device, dtype=dtype)
 
-        if ep_size > 1:
-            local_e = e // ep_size
-            e_ids = torch.randint(0,
-                                  e, (local_e, ),
-                                  device=device,
-                                  dtype=torch.int32)
-            e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
-            e_map[e_ids] = torch.arange(local_e,
-                                        device=device,
-                                        dtype=torch.int32)
-            w1 = w1[e_ids]
-            w2 = w2[e_ids]
-        else:
-            e_map = None
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device=device,
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
 
-        score = torch.softmax(score, dim=-1, dtype=dtype)
-        topk_weights, topk_ids = torch.topk(score, topk)
-        topk_ids = topk_ids.to(torch.int32)
+    score = torch.softmax(score, dim=-1, dtype=dtype)
+    topk_weights, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.to(torch.int32)
 
-        output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
-        torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
-                                 e_map)
-        # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
-        torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
+    output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
+    torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
     torch.npu.empty_cache()
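For readers of this change: the `ep_size > 1` branch simulates expert parallelism by picking `local_e` of the `e` global experts to host locally and building `e_map`, a table that translates global expert ids into local slot indices, with `-1` marking experts that live on another rank. Below is a minimal standalone sketch of that mapping; the sizes (`e = 4`, `ep_size = 2`) and the hosted expert ids are made up purely for illustration.

```python
import torch

e, ep_size = 4, 2       # hypothetical sizes, not from the test matrix
local_e = e // ep_size  # experts hosted on this simulated rank

# Suppose this rank hosts global experts 3 and 1 (fixed here instead of
# torch.randint so the result is reproducible).
e_ids = torch.tensor([3, 1])

# e_map[global_id] -> local slot index, or -1 for experts on other ranks.
e_map = torch.full((e, ), -1, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, dtype=torch.int32)

print(e_map)  # tensor([-1,  1, -1,  0], dtype=torch.int32)
```

Routed `topk_ids` that land on a `-1` entry belong to experts held elsewhere, which is why the test passes the same `e_map` to both `fused_experts` and the `torch_moe` reference so the two stay comparable.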