[Refactor] Refactor sampler #2050
Added file (32 lines): unit tests for `AscendSampler` and `AscendTopKTopPSampler`.

```python
from unittest import mock

import torch

from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler


class TestAscendSampler(TestBase):

    def test_init_with_raw_logprobs(self):
        sampler = AscendSampler(logprobs_mode="raw_logprobs")
        self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
        self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)


class TestAscendTopKTopPSampler(TestBase):

    @mock.patch("torch_npu.npu_top_k_top_p")
    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
        mock_npu_op.return_value = torch.randn(1, 3)
        sampler = AscendTopKTopPSampler()

        logits = torch.tensor([[1.0, 2.0, 3.0]])
        k = torch.tensor([2])
        p = torch.tensor([0.9])
        generators = {0: torch.Generator()}
        generators[0].manual_seed(42)

        sampler.forward_native(logits, generators, k, p)
        mock_npu_op.assert_called_once_with(logits, p, k)
```
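For context, `forward_native` ends by drawing tokens with vLLM's `random_sample`, which consults the per-request `generators` dict, so seeding generator 0 as above makes row 0's draw reproducible. A minimal CPU sketch of that contract (the import path comes from this PR; the standalone call below is illustrative, not part of the change):

```python
import torch
from vllm.v1.sample.ops.topk_topp_sampler import random_sample

probs = torch.tensor([[0.1, 0.2, 0.7]])     # already-normalized probabilities
gen = torch.Generator().manual_seed(42)     # generator for request index 0
token_ids = random_sample(probs, {0: gen})  # deterministic given the seed
print(token_ids)
```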
Added file (62 lines): the Ascend sampler implementation.

```python
import torch
import torch_npu
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler


class AscendSampler(Sampler):

    def __init__(self, logprobs_mode="raw_logprobs"):
        # TODO: support logprobs_mode in vllm-ascend
        super().__init__(logprobs_mode=logprobs_mode)
        self.topk_topp_sampler = AscendTopKTopPSampler()


class AscendTopKTopPSampler(TopKTopPSampler):

    def _apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        k: torch.Tensor,
        p: torch.Tensor,
    ) -> torch.Tensor:
        if p is not None and k is not None:
            # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
            return torch_npu.npu_top_k_top_p(logits, p, k)

        if p is None and k is None:
            return logits

        probs = logits.softmax(dim=-1)
        probs_sort, _ = probs.sort(dim=-1, descending=False)

        if k is not None:
            top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
            top_k_count = top_k_count.unsqueeze(dim=1)
            top_k_cutoff = probs_sort.gather(-1, top_k_count)

            # Rows where k equals the vocab size keep all tokens (no-op).
            no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
            top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

            elements_to_discard = probs < top_k_cutoff
            logits.masked_fill_(elements_to_discard, -float("inf"))

        if p is not None:
            cumprob = torch.cumsum(probs_sort, dim=-1)
            top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
            top_p_mask[:, -1] = False  # keep at least one token

            top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
            top_p_cutoff = probs_sort.gather(-1, top_p_count)
            elements_to_discard = probs < top_p_cutoff
            logits.masked_fill_(elements_to_discard, -float("inf"))

        return logits

    def forward_native(self, logits, generators, k, p):
        """Override the PyTorch-native implementation with torch_npu."""
        logits = self._apply_top_k_top_p(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)
```
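To make the fallback path concrete, here is a small CPU-only walkthrough of the top-k branch above (a toy illustration with made-up numbers, not part of the PR, and requiring no `torch_npu`):

```python
import torch

# One row of 4 logits; keep the top-2 tokens (k=2), no top-p filtering.
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
k = torch.tensor([2])

probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)

# The cutoff is the k-th largest probability, i.e. index (vocab_size - k)
# in the ascending sort.
top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
top_k_cutoff = probs_sort.gather(-1, top_k_count)

masked = logits.masked_fill(probs < top_k_cutoff, -float("inf"))
print(masked)  # tensor([[-inf, 3., 2., -inf]]): only the top-2 logits survive
```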
The model runner is updated to select the sampler via an environment flag:

```diff
@@ -64,14 +64,14 @@
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
@@ -165,7 +165,15 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
         self.dtype = self.model_config.dtype
-        self.sampler = Sampler()
+        if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
+            # TODO: drop the env config to use ascend sampler by default
+            from vllm_ascend.sample.sampler import AscendSampler
+
+            self.sampler = AscendSampler()
+        else:
+            from vllm.v1.sample.sampler import Sampler
+
+            self.sampler = Sampler()
 
         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
```

Review thread on this change:

> **Reviewer:** Can it take effect directly by removing the flag?
>
> **Author:** Yes, I hope so. Let's run a full round of testing to make sure the Ascend sampler is stable enough; then we can remove the flag.
>
> **Author:** I have set the value to `True`, so the Ascend sampler is now enabled by default.
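Per the thread above, the flag now defaults to true. As a hedged sketch of toggling it explicitly (the variable name comes from this diff; when and how `vllm_ascend.envs` parses it is an assumption):

```python
import os

# Hypothetical launcher snippet: set the flag before vllm_ascend's envs
# module is imported, since environment variables are typically read there.
# "1" selects AscendSampler; "0" falls back to vLLM's stock Sampler.
os.environ["VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION"] = "1"
```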
Review thread on `_apply_top_k_top_p`:

> **Reviewer:** Unrelated: it seems the `probs` and `probs_sort` computations are unused when `p` is None and `k` is None; the logits can be returned directly.
>
> **Reviewer:** cc @Pr0Wh1teGivee
>
> **Author:** Done.