1+ #
2+ # Licensed under the Apache License, Version 2.0 (the "License");
3+ # you may not use this file except in compliance with the License.
4+ # You may obtain a copy of the License at
5+ #
6+ # http://www.apache.org/licenses/LICENSE-2.0
7+ #
8+ # Unless required by applicable law or agreed to in writing, software
9+ # distributed under the License is distributed on an "AS IS" BASIS,
10+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+ # See the License for the specific language governing permissions and
12+ # limitations under the License.
13+ # This file is a part of the vllm-ascend project.
14+ #
15+
116# The fused MoE ops test hits the infer_schema error, so we need to add the
217# patch here to make the test pass.
318import vllm_ascend .patch .worker .patch_common .patch_utils # type: ignore[import] # isort: skip # noqa
419
520import json
6- import unittest
21+ import os
722from typing import List , TypedDict
823from unittest import mock
924
1025import torch
1126
27+ from tests .ut .base import TestBase
1228from vllm_ascend .ops .expert_load_balancer import ExpertLoadBalancer
1329
1430
@@ -28,31 +44,13 @@ class MockData(TypedDict):
2844 layer_list : List [Layer ]
2945
3046
31- MOCK_DATA : MockData = {
32- "moe_layer_count" :
33- 1 ,
34- "layer_list" : [{
35- "layer_id" :
36- 0 ,
37- "device_count" :
38- 2 ,
39- "device_list" : [{
40- "device_id" : 0 ,
41- "device_expert" : [7 , 2 , 0 , 3 , 5 ]
42- }, {
43- "device_id" : 1 ,
44- "device_expert" : [6 , 1 , 4 , 7 , 2 ]
45- }]
46- }]
47- }
48-
49-
50- class TestExpertLoadBalancer (unittest .TestCase ):
47+ class TestExpertLoadBalancer (TestBase ):
5148
5249 def setUp (self ):
53- json_file = "expert_map.json"
54- with open (json_file , 'w' ) as f :
55- json .dump (MOCK_DATA , f )
50+ _TEST_DIR = os .path .dirname (__file__ )
51+ json_file = _TEST_DIR + "/expert_map.json"
52+ with open (json_file , 'r' ) as f :
53+ self .expert_map : MockData = json .load (f )
5654
5755 self .expert_load_balancer = ExpertLoadBalancer (json_file ,
5856 global_expert_num = 8 )
@@ -62,9 +60,9 @@ def test_init(self):
6260 self .assertIsInstance (self .expert_load_balancer .expert_map_tensor ,
6361 torch .Tensor )
6462 self .assertEqual (self .expert_load_balancer .layers_num ,
65- MOCK_DATA ["moe_layer_count" ])
63+ self . expert_map ["moe_layer_count" ])
6664 self .assertEqual (self .expert_load_balancer .ranks_num ,
67- MOCK_DATA ["layer_list" ][0 ]["device_count" ])
65+ self . expert_map ["layer_list" ][0 ]["device_count" ])
6866
6967 def test_generate_index_dicts (self ):
7068 tensor_2d = torch .tensor ([[7 , 2 , 0 , 3 , 5 ], [6 , 1 , 4 , 7 , 2 ]])
@@ -142,6 +140,6 @@ def test_get_rank_log2phy_map(self):
142140 def test_get_global_redundant_expert_num (self ):
143141 redundant_expert_num = self .expert_load_balancer .get_global_redundant_expert_num (
144142 )
145- expected_redundant_expert_num = len (MOCK_DATA ["layer_list" ][0 ]["device_list" ][0 ]["device_expert" ]) * \
146- MOCK_DATA ["layer_list" ][0 ]["device_count" ] - 8
143+ expected_redundant_expert_num = len (self . expert_map ["layer_list" ][0 ]["device_list" ][0 ]["device_expert" ]) * \
144+ self . expert_map ["layer_list" ][0 ]["device_count" ] - 8
147145 self .assertEqual (redundant_expert_num , expected_redundant_expert_num )
0 commit comments