Skip to content

Commit 2bb06c5

Browse files
committed
use fused ops npu_top_k_top_p which is introduced in https://mirrors.huaweicloud.com/ascend/repos/pypi/torch-npu/
1 parent b3d6e0c commit 2bb06c5

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# This file is a part of the vllm-ascend project.
14+
15+
import importlib
16+
import os
17+
import unittest
18+
from unittest import mock
19+
20+
import torch
21+
from vllm.v1.sample.ops import topk_topp_sampler
22+
23+
24+
class TestTopKTopPSamplerOptimize(unittest.TestCase):
25+
26+
@mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
27+
@mock.patch("torch_npu.npu_top_k_top_p")
28+
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
29+
# We have to patch and reload because the patch will take effect
30+
# only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
31+
import vllm_ascend.patch.worker.patch_common.patch_sampler
32+
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
33+
34+
mock_npu_op.return_value = (torch.randn(1, 3))
35+
sampler = topk_topp_sampler.TopKTopPSampler()
36+
37+
logits = torch.tensor([[1.0, 2.0, 3.0]])
38+
k = torch.tensor([2])
39+
p = torch.tensor([0.9])
40+
generators = {0: torch.Generator()}
41+
generators[0].manual_seed(42)
42+
43+
sampler.forward_native(logits, generators, k, p)
44+
mock_npu_op.assert_called_once_with(logits, p, k)

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional
2020

2121
import torch
22+
import torch_npu
2223
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
2324
from vllm.v1.sample.sampler import Sampler
2425

@@ -46,11 +47,15 @@ def apply_min_p(
4647
return logits
4748

4849

49-
def _apply_top_k_top_p(
50+
def apply_top_k_top_p(
5051
logits: torch.Tensor,
51-
p: torch.Tensor,
5252
k: torch.Tensor,
53+
p: torch.Tensor,
5354
) -> torch.Tensor:
55+
if p is not None and k is not None:
56+
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
57+
return torch_npu.npu_top_k_top_p(logits, p, k)
58+
5459
probs = logits.softmax(dim=-1)
5560
probs_sort, _ = probs.sort(dim=-1, descending=False)
5661

@@ -91,7 +96,7 @@ def topk_topp_forward_native(
9196
9297
The logits tensor may be updated in-place.
9398
"""
94-
logits = _apply_top_k_top_p(logits, k, p)
99+
logits = apply_top_k_top_p(logits, k, p)
95100
probs = logits.softmax(dim=-1, dtype=torch.float32)
96101
return random_sample(probs, generators)
97102

0 commit comments

Comments
 (0)