46 changes: 0 additions & 46 deletions tests/ut/patch/worker/patch_common/test_patch_sampler.py

This file was deleted.

32 changes: 32 additions & 0 deletions tests/ut/sample/test_sampler.py
@@ -0,0 +1,32 @@
from unittest import mock

import torch

from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler


class TestAscendSampler(TestBase):

def test_init_with_raw_logprobs(self):
sampler = AscendSampler(logprobs_mode="raw_logprobs")
self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)


class TestAscendTopKTopPSampler(TestBase):

@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
mock_npu_op.return_value = (torch.randn(1, 3))
sampler = AscendTopKTopPSampler()

logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)

sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)
6 changes: 3 additions & 3 deletions vllm_ascend/envs.py
@@ -128,11 +128,11 @@
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
lambda: int(
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
# Whether to enable the topk optimization. It's disabled by default for experimental support
# We'll make it enabled by default in the future.
# Whether to enable the topk optimization. It's enabled by default. Please set to False if you hit any issue.
# We'll remove this flag in the future once it's stable enough.
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '1'))),

# `LLMDataDistCMgrConnector` required variable. `DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is
# used for llmdatadist to build the communication topology for kv cache transfer, it is
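For readers tracking this default flip: the sketch below shows how a deployment could still opt out after upgrading. It assumes only what the diff shows (the variable name, the '0'/'1' encoding, and that vllm_ascend.envs resolves its entries lazily at attribute access, as the model_runner_v1.py hunk further down suggests); everything else is illustrative.

import os

# Opt-out sketch: this PR flips the flag's default from '0' to '1', so the
# NPU top-k/top-p path is on unless the variable is exported as '0' before
# vllm_ascend reads it.
os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "0"

from vllm_ascend import envs  # requires vllm-ascend to be installed

# With '0' set, bool(int('0')) evaluates to False and the model runner should
# fall back to vLLM's stock Sampler (see the model_runner_v1.py hunk below).
print(envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION)  # expected: False
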
16 changes: 1 addition & 15 deletions vllm_ascend/patch/__init__.py
@@ -88,21 +88,7 @@
# Future Plan:
# Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
#
# ** File: worker/patch_common/patch_sampler.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
# Why:
# We need to use the patched `apply_top_k_top_p` in `sample`.
# The mainly reason to overwrite `apply_top_k_top_p` is
# to improve performance.
# How:
# Re-implementation the `apply_top_k_top_p` function by pytorch
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm-ascend/pull/1732
# Future Plan:
# Revert it when the ascend scatter performance improves.
#
# ** File: worker/patch_common/patch_sampler.py **
# ** File: worker/patch_0_10_0/patch_sampler_gather_logprobs.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.sampler.Sampler.gather_logprobs`
# Why:
1 change: 0 additions & 1 deletion vllm_ascend/patch/worker/patch_common/__init__.py
@@ -21,4 +21,3 @@
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
83 changes: 0 additions & 83 deletions vllm_ascend/patch/worker/patch_common/patch_sampler.py

This file was deleted.

62 changes: 62 additions & 0 deletions vllm_ascend/sample/sampler.py
@@ -0,0 +1,62 @@
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler


class AscendSampler(Sampler):

def __init__(self, logprobs_mode="raw_logprobs"):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)
self.topk_topp_sampler = AscendTopKTopPSampler()


class AscendTopKTopPSampler(TopKTopPSampler):

def _apply_top_k_top_p(
self,
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
if p is not None and k is not None:
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
return torch_npu.npu_top_k_top_p(logits, p, k)

if p is None and k is None:
return logits

probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
Comment on lines +30 to +31

Collaborator: Unrelated: it seems the probs and probs_sort computation is unused if p is None and k is None; the logits can be returned directly.

Collaborator (Author): Done


if k is not None:
top_k_count = probs_sort.size(1) - k.to(
torch.long) # shape: (batch, )
top_k_count = top_k_count.unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)

# Make sure the no top-k rows are no-op.
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

elements_to_discard = probs < top_k_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))

if p is not None:
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one

top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))

return logits

def forward_native(self, logits, generators, k, p):
"""Override pytorch native implementation to torch_npu"""
logits = self._apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
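One behavioral detail worth spelling out: the fused torch_npu kernel is only taken when both k and p are set; a k-only or p-only request falls through to the pure-PyTorch branch, and a request with neither returns the logits untouched. Below is a minimal sketch of that dispatch, assuming vllm_ascend and torch_npu are importable; the mock stands in for the real NPU kernel, mirroring the unit test earlier in this diff.

from unittest import mock

import torch

from vllm_ascend.sample.sampler import AscendTopKTopPSampler

sampler = AscendTopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
k, p = torch.tensor([2]), torch.tensor([0.9])

with mock.patch("torch_npu.npu_top_k_top_p", return_value=logits) as npu_op:
    # Both k and p present -> the fused NPU kernel is used, with the
    # (logits, p, k) argument order noted in the code comment above.
    x = logits.clone()
    sampler._apply_top_k_top_p(x, k, p)
    args, _ = npu_op.call_args
    assert args[0] is x and args[1] is p and args[2] is k

    # Only k present -> the pure-PyTorch fallback runs instead; for these
    # logits the smallest entry is masked to -inf.
    out = sampler._apply_top_k_top_p(logits.clone(), k, None)
    assert out[0, 0] == -float("inf")
    npu_op.assert_called_once()  # the NPU kernel was not called a second time
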
12 changes: 10 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -64,14 +64,14 @@
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device
self.dtype = self.model_config.dtype
self.sampler = Sampler()
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
# TODO: drop the env config to use ascend sampler by default
from vllm_ascend.sample.sampler import AscendSampler

self.sampler = AscendSampler()
else:
from vllm.v1.sample.sampler import Sampler

self.sampler = Sampler()
Collaborator: Can it take effect directly by removing the flag?

Collaborator (Author): Yes, I hope so. Let's run a full test to ensure the Ascend sampler is stable enough; then we can remove the flag.

Collaborator (Author): I set the value to True to enable the Ascend sampler by default now.


# Lazy initialization, these will be set after __init__
self.kv_caches: List[torch.Tensor] = []