1+ #
2+ # Licensed under the Apache License, Version 2.0 (the "License");
3+ # you may not use this file except in compliance with the License.
4+ # You may obtain a copy of the License at
5+ #
6+ # http://www.apache.org/licenses/LICENSE-2.0
7+ #
8+ # Unless required by applicable law or agreed to in writing, software
9+ # distributed under the License is distributed on an "AS IS" BASIS,
10+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+ # See the License for the specific language governing permissions and
12+ # limitations under the License.
13+ # This file is a part of the vllm-ascend project.
14+
15+ import importlib
16+ import os
17+ import unittest
18+ from unittest import mock
19+
20+ import torch
21+ from vllm .v1 .sample .ops import topk_topp_sampler
22+
23+
class TestTopKTopPSamplerOptimize(unittest.TestCase):
    """Tests for ``TopKTopPSampler`` with the Ascend top-k optimization on."""

    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
    @mock.patch("torch_npu.npu_top_k_top_p")
    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
        """Check ``forward_native`` calls the mocked NPU op once.

        The sampler receives its arguments as ``(logits, generators, k, p)``
        while the mocked op is expected to be called as ``(logits, p, k)``.
        """
        # The sampler patch only takes effect once
        # VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set, so the patch module is
        # reloaded here, while the environment override is active.
        from vllm_ascend.patch.worker.patch_common import patch_sampler
        importlib.reload(patch_sampler)

        # Canned tensor handed back by the mocked op.
        mock_npu_op.return_value = torch.randn(1, 3)
        sampler_under_test = topk_topp_sampler.TopKTopPSampler()

        # One request (key 0) with a deterministically seeded generator.
        seeded_generator = torch.Generator()
        seeded_generator.manual_seed(42)
        per_request_generators = {0: seeded_generator}

        logits = torch.tensor([[1.0, 2.0, 3.0]])
        top_k = torch.tensor([2])
        top_p = torch.tensor([0.9])

        sampler_under_test.forward_native(
            logits,
            per_request_generators,
            top_k,
            top_p,
        )

        # Note the swapped order: p precedes k in the expected op call.
        mock_npu_op.assert_called_once_with(logits, top_p, top_k)
0 commit comments