1- # fused moe ops test will hit the infer_schema error, we need add the patch
2- # here to make the test pass.
3- import vllm_ascend .patch .worker .patch_common .patch_utils # type: ignore[import] # isort: skip # noqa
1+ #
2+ # Licensed under the Apache License, Version 2.0 (the "License");
3+ # you may not use this file except in compliance with the License.
4+ # You may obtain a copy of the License at
5+ #
6+ # http://www.apache.org/licenses/LICENSE-2.0
7+ #
8+ # Unless required by applicable law or agreed to in writing, software
9+ # distributed under the License is distributed on an "AS IS" BASIS,
10+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+ # See the License for the specific language governing permissions and
12+ # limitations under the License.
13+ # This file is a part of the vllm-ascend project.
14+ #
415
516import json
6- import unittest
17+ import os
718from typing import List , TypedDict
819from unittest import mock
920
1021import torch
1122
23+ from tests .ut .base import TestBase
1224from vllm_ascend .ops .expert_load_balancer import ExpertLoadBalancer
1325
1426
@@ -28,31 +40,13 @@ class MockData(TypedDict):
2840 layer_list : List [Layer ]
2941
3042
31- MOCK_DATA : MockData = {
32- "moe_layer_count" :
33- 1 ,
34- "layer_list" : [{
35- "layer_id" :
36- 0 ,
37- "device_count" :
38- 2 ,
39- "device_list" : [{
40- "device_id" : 0 ,
41- "device_expert" : [7 , 2 , 0 , 3 , 5 ]
42- }, {
43- "device_id" : 1 ,
44- "device_expert" : [6 , 1 , 4 , 7 , 2 ]
45- }]
46- }]
47- }
48-
49-
50- class TestExpertLoadBalancer (unittest .TestCase ):
43+ class TestExpertLoadBalancer (TestBase ):
5144
5245 def setUp (self ):
53- json_file = "expert_map.json"
54- with open (json_file , 'w' ) as f :
55- json .dump (MOCK_DATA , f )
46+ _TEST_DIR = os .path .dirname (__file__ )
47+ json_file = _TEST_DIR + "/expert_map.json"
48+ with open (json_file , 'r' ) as f :
49+ self .expert_map : MockData = json .load (f )
5650
5751 self .expert_load_balancer = ExpertLoadBalancer (json_file ,
5852 global_expert_num = 8 )
@@ -62,9 +56,9 @@ def test_init(self):
6256 self .assertIsInstance (self .expert_load_balancer .expert_map_tensor ,
6357 torch .Tensor )
6458 self .assertEqual (self .expert_load_balancer .layers_num ,
65- MOCK_DATA ["moe_layer_count" ])
59+ self . expert_map ["moe_layer_count" ])
6660 self .assertEqual (self .expert_load_balancer .ranks_num ,
67- MOCK_DATA ["layer_list" ][0 ]["device_count" ])
61+ self . expert_map ["layer_list" ][0 ]["device_count" ])
6862
6963 def test_generate_index_dicts (self ):
7064 tensor_2d = torch .tensor ([[7 , 2 , 0 , 3 , 5 ], [6 , 1 , 4 , 7 , 2 ]])
@@ -142,6 +136,6 @@ def test_get_rank_log2phy_map(self):
142136 def test_get_global_redundant_expert_num (self ):
143137 redundant_expert_num = self .expert_load_balancer .get_global_redundant_expert_num (
144138 )
145- expected_redundant_expert_num = len (MOCK_DATA ["layer_list" ][0 ]["device_list" ][0 ]["device_expert" ]) * \
146- MOCK_DATA ["layer_list" ][0 ]["device_count" ] - 8
139+ expected_redundant_expert_num = len (self . expert_map ["layer_list" ][0 ]["device_list" ][0 ]["device_expert" ]) * \
140+ self . expert_map ["layer_list" ][0 ]["device_count" ] - 8
147141 self .assertEqual (redundant_expert_num , expected_redundant_expert_num )
0 commit comments