From 7966750b84233c473e7e08ad53bb15d7f30925ba Mon Sep 17 00:00:00 2001
From: Pr0Wh1teGivee
Date: Tue, 22 Jul 2025 10:10:34 +0800
Subject: [PATCH] use fused op npu_top_k_top_p, which is introduced in
 https://mirrors.huaweicloud.com/ascend/repos/pypi/torch-npu/

Signed-off-by: Pr0Wh1teGivee
---
 .../worker/patch_common/test_patch_sampler.py | 44 +++++++++++++++++++
 .../worker/patch_common/patch_sampler.py      | 11 +++--
 2 files changed, 52 insertions(+), 3 deletions(-)
 create mode 100644 tests/ut/patch/worker/patch_common/test_patch_sampler.py

diff --git a/tests/ut/patch/worker/patch_common/test_patch_sampler.py b/tests/ut/patch/worker/patch_common/test_patch_sampler.py
new file mode 100644
index 0000000000..bc3a14ae91
--- /dev/null
+++ b/tests/ut/patch/worker/patch_common/test_patch_sampler.py
@@ -0,0 +1,44 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+
+import importlib
+import os
+import unittest
+from unittest import mock
+
+import torch
+from vllm.v1.sample.ops import topk_topp_sampler
+
+
+class TestTopKTopPSamplerOptimize(unittest.TestCase):
+
+    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
+    @mock.patch("torch_npu.npu_top_k_top_p")
+    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+        # We have to patch and reload because the patch takes effect
+        # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
+        import vllm_ascend.patch.worker.patch_common.patch_sampler
+        importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
+
+        mock_npu_op.return_value = torch.randn(1, 3)
+        sampler = topk_topp_sampler.TopKTopPSampler()
+
+        logits = torch.tensor([[1.0, 2.0, 3.0]])
+        k = torch.tensor([2])
+        p = torch.tensor([0.9])
+        generators = {0: torch.Generator()}
+        generators[0].manual_seed(42)
+
+        sampler.forward_native(logits, generators, k, p)
+        mock_npu_op.assert_called_once_with(logits, p, k)
diff --git a/vllm_ascend/patch/worker/patch_common/patch_sampler.py b/vllm_ascend/patch/worker/patch_common/patch_sampler.py
index a6fbfbc01f..a01d0de6f1 100644
--- a/vllm_ascend/patch/worker/patch_common/patch_sampler.py
+++ b/vllm_ascend/patch/worker/patch_common/patch_sampler.py
@@ -19,6 +19,7 @@
 from typing import Optional
 
 import torch
+import torch_npu
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
 from vllm.v1.sample.sampler import Sampler
 
@@ -46,11 +47,15 @@ def apply_min_p(
     return logits
 
 
-def _apply_top_k_top_p(
+def apply_top_k_top_p(
     logits: torch.Tensor,
-    p: torch.Tensor,
     k: torch.Tensor,
+    p: torch.Tensor,
 ) -> torch.Tensor:
+    if p is not None and k is not None:
+        # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
+        return torch_npu.npu_top_k_top_p(logits, p, k)
+
     probs = logits.softmax(dim=-1)
 
     probs_sort, _ = probs.sort(dim=-1, descending=False)
@@ -91,7 +96,7 @@ def topk_topp_forward_native(
     The logits tensor may be updated in-place.
""" - logits = _apply_top_k_top_p(logits, k, p) + logits = apply_top_k_top_p(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators)