|
1 | | -import importlib |
| 1 | +# |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# This file is a part of the vllm-ascend project. |
| 14 | +# |
| 15 | + |
2 | 16 | import os |
3 | | -import unittest |
4 | 17 | from unittest import mock |
5 | 18 |
|
6 | 19 | import torch |
7 | 20 | from vllm.v1.sample.ops import topk_topp_sampler |
8 | 21 |
|
| 22 | +from tests.ut.base import TestBase |
| 23 | + |
9 | 24 |
|
10 | | -class TestTopKTopPSamplerOptimize(unittest.TestCase): |
| 25 | +class TestTopKTopPSamplerOptimize(TestBase): |
11 | 26 |
|
12 | 27 | @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}) |
13 | 28 | @mock.patch("torch_npu.npu_top_k_top_p") |
14 | 29 | def test_npu_topk_topp_called_when_optimized(self, mock_npu_op): |
15 | | - import vllm_ascend.patch.worker.patch_common.patch_sampler |
16 | | - importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler) |
17 | | - |
18 | 30 | mock_npu_op.return_value = (torch.randn(1, 3)) |
19 | 31 | sampler = topk_topp_sampler.TopKTopPSampler() |
20 | 32 |
|
|
0 commit comments