Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/ut/patch/worker/patch_common/test_patch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib
import os
import unittest
from unittest import mock

import torch
from vllm.v1.sample.ops import topk_topp_sampler


class TestTopKTopPSamplerOptimize(unittest.TestCase):

@mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
# We have to patch and reload because the patch will take effect
# only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
import vllm_ascend.patch.worker.patch_common.patch_sampler
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)

mock_npu_op.return_value = (torch.randn(1, 3))
sampler = topk_topp_sampler.TopKTopPSampler()

logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)

sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)
11 changes: 8 additions & 3 deletions vllm_ascend/patch/worker/patch_common/patch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Optional

import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler

Expand Down Expand Up @@ -46,11 +47,15 @@ def apply_min_p(
return logits


def _apply_top_k_top_p(
def apply_top_k_top_p(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
if p is not None and k is not None:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)

probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)

Expand Down Expand Up @@ -91,7 +96,7 @@ def topk_topp_forward_native(

The logits tensor may be updated in-place.
"""
logits = _apply_top_k_top_p(logits, k, p)
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)

Expand Down