From b55ffcaea32ae997d04841527abb6c52464749da Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Tue, 15 Apr 2025 06:40:47 +0800 Subject: [PATCH 1/2] perf(npu): greatly accelerate post-processing on Ascend platform Signed-off-by: linfeng-yuan <1102311262@qq.com> --- vllm_ascend/sample/__init__.py | 0 vllm_ascend/sample/ops/__init__.py | 0 .../sample/ops/ascend_topk_topp_sampler.py | 64 ++++++++ vllm_ascend/sample/ops/penalties.py | 67 ++++++++ vllm_ascend/sample/sampler.py | 143 ++++++++++++++++++ vllm_ascend/sample/sampler_v1.py | 38 +++++ vllm_ascend/worker/model_runner.py | 8 +- vllm_ascend/worker/model_runner_v1.py | 31 ++++ 8 files changed, 350 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/sample/__init__.py create mode 100644 vllm_ascend/sample/ops/__init__.py create mode 100644 vllm_ascend/sample/ops/ascend_topk_topp_sampler.py create mode 100644 vllm_ascend/sample/ops/penalties.py create mode 100644 vllm_ascend/sample/sampler.py create mode 100644 vllm_ascend/sample/sampler_v1.py diff --git a/vllm_ascend/sample/__init__.py b/vllm_ascend/sample/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/sample/ops/__init__.py b/vllm_ascend/sample/ops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py new file mode 100644 index 0000000000..78130412b7 --- /dev/null +++ b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py @@ -0,0 +1,64 @@ +from typing import Dict, Optional + +import torch +import torch.nn as nn + +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +class AscendTopKTopPSampler(TopKTopPSampler): + + def __init__(self): + super().__init__() + # TODO(linfeng): eliminate warning for FlashInfer here + self.forward = self.forward_npu + + def forward_npu( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> torch.Tensor: + """Optimized implementation of top-k and top-p sampling on NPU.""" + logits = apply_top_k_top_p_npu(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + +def apply_top_k_top_p_npu( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p optimized for NPU. + + This algorithm avoids using torch.scatter which is time-consuming on NPU. 
+ """ + # TODO(linfeng): consider the case taht either p or k is applied + if k is None and p is None: + return logits + batch_size, vocab_size = logits.shape + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) + top_k_mask = logits_sort < boundary + logits_sort.masked_fill_(top_k_mask, -float("inf")) + cutoff = top_k_mask.sum(dim=-1).min() + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = True + strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device) + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) + valid_idx = torch.masked_select(flatten_idx, top_p_mask) + + logits_flatten = logits.flatten() + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) + logits[valid_idx] = valid_logits + return logits.reshape(batch_size, vocab_size) \ No newline at end of file diff --git a/vllm_ascend/sample/ops/penalties.py b/vllm_ascend/sample/ops/penalties.py new file mode 100644 index 0000000000..c84aee7dfb --- /dev/null +++ b/vllm_ascend/sample/ops/penalties.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.v1.sample.ops.penalties import _convert_to_tensors +from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask + + +def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + """Optimized implementation of repetition penalties on NPU. + + Applies penalties in place to the logits tensor + logits : The input logits tensor of shape [num_seqs, vocab_size] + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID + in the vocabulary. + output_tokens_tensor: The output tokens tensor. + presence_penalties: The presence penalties of shape (num_seqs, ) + frequency_penalties: The frequency penalties of shape (num_seqs, ) + repetition_penalties: The repetition penalties of shape (num_seqs, ) + """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, vocab_size) + + # Avoid IndexPut operations in original apply_penalties function which are extremely time-consuming on NPU. + sequence_mask = prompt_mask | output_mask + logits = torch.where(sequence_mask & torch.lt(logits, 0), logits * repetition_penalties, + logits).to(logits.dtype) + logits = torch.where(sequence_mask & torch.ge(logits, 0), logits / repetition_penalties, + logits).to(logits.dtype) + + # We follow the definition in OpenAI API. 
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask + return logits + +def apply_all_penalties( + logits: torch.Tensor, + prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: list[list[int]], +) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) \ No newline at end of file diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py new file mode 100644 index 0000000000..18248afa0f --- /dev/null +++ b/vllm_ascend/sample/sampler.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + +import torch +from vllm.model_executor.layers.sampler import (Sampler, + SamplerOutput, + _apply_min_tokens_penalty, + _apply_min_p, + _sample, + SampleResultArgsType, + get_logprobs, + _build_sampler_output) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm_ascend.sample.ops.penalties import apply_penalties + + +class AscendSampler(Sampler): + + def __init__(self): + super().__init__() + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + assert logits is not None + _, vocab_size = logits.shape + + # Prepare sampling tensors with pinned memory to avoid blocking. + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + + # Apply presence and frequency penalties. + if do_penalties: + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Use float32 to apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits = logits.to(torch.float) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. 
+ maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. + on_device_tensors = None + + # Get the logprobs query results. + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) + + return _build_sampler_output( + maybe_deferred_sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) + + +def _apply_top_k_top_p_npu( + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + """Apply top-k and top-p optimized for NPU. + + This algorithm avoids using torch.scatter which is time-consuming on NPU. + """ + # TODO(linfeng): consider the case taht either p or k is applied + batch_size, vocab_size = logits.shape + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) + top_k_mask = logits_sort < boundary + logits_sort.masked_fill_(top_k_mask, -float("inf")) + cutoff = top_k_mask.sum(dim=-1).min() + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = True + strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device) + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) + valid_idx = torch.masked_select(flatten_idx, top_p_mask) + logits_flatten = logits.flatten() + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) + logits[valid_idx] = valid_logits + return logits.reshape(batch_size, vocab_size) diff --git a/vllm_ascend/sample/sampler_v1.py b/vllm_ascend/sample/sampler_v1.py new file mode 100644 index 0000000000..0cb8006dcd --- /dev/null +++ b/vllm_ascend/sample/sampler_v1.py @@ -0,0 +1,38 @@ +import torch +from vllm.v1.sample.sampler import Sampler +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.penalties import apply_min_token_penalties +from vllm.logger import init_logger +from vllm_ascend.sample.ops.ascend_topk_topp_sampler import AscendTopKTopPSampler +from vllm_ascend.sample.ops.penalties import apply_all_penalties + + +logger = init_logger(__name__) + + +class AscendSampler(Sampler): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = AscendTopKTopPSampler() + + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.min_tokens) + if not 
sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + logits = apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids, + ) + return logits \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index c149ddf9ff..139fb7e3a9 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -60,6 +60,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm_ascend.sample.sampler import AscendSampler if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -820,7 +821,12 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: self.model = get_model(vllm_config=self.vllm_config) - + # Same options with those in model_runner_v1.py + # option 1 + if hasattr(self.model, "sampler"): + self.model.sampler = AscendSampler() + # option 2 + # self.model = NPUModelWrapperV1(model) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 31227359f5..6a97c80c8c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -33,7 +33,9 @@ from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.sampler import sampler_output from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.platforms import current_platform from vllm.sampling_params import SamplingType @@ -52,6 +54,7 @@ from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) +from vllm_ascend.sample.sampler_v1 import AscendSampler if TYPE_CHECKING: from vllm.v1.core.scheduler_output import SchedulerOutput @@ -810,6 +813,12 @@ def load_model(self) -> None: with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + # option 1 + if hasattr(self.model, "sampler"): + self.model.sampler = AscendSampler() + # option 2 + # self.model = NPUModelWrapperV1(model) + if self.lora_config: raise ValueError("LoRA model is not supported on NPU now.") @@ -889,3 +898,25 @@ def get_kv_cache_spec(self) -> KVCacheSpec: f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec + +# class NPUModelWrapperV1(nn.Module): + +# def __init__(self, model: nn.Module): +# super().__init__() +# self._model = model +# self.sampler = AscendSampler() + +# def __getattr__(self, name): +# return getattr(self._model, name) + +# def sample( +# self, +# logits: Optional[torch.Tensor], +# sampling_metadata: SamplingMetadata, +# ) -> Optional[SamplerOutput]: +# next_tokens = self.sampler(logits, sampling_metadata) +# return next_tokens + +# def forward(): +# # necessary if using wrapper class +# pass From d377ba3acf89e66d9a446e87968886358de5d80d Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Tue, 15 
Apr 2025 16:11:09 +0800 Subject: [PATCH 2/2] refactor: support scenarios where top_p or top_k is None Signed-off-by: linfeng-yuan <1102311262@qq.com> --- .../sample/ops/ascend_topk_topp_sampler.py | 66 +++++++++---------- vllm_ascend/sample/ops/penalties.py | 19 +++--- vllm_ascend/sample/sampler.py | 30 ++++----- vllm_ascend/sample/sampler_v1.py | 12 ++-- vllm_ascend/worker/model_runner.py | 8 +-- vllm_ascend/worker/model_runner_v1.py | 30 +-------- 6 files changed, 63 insertions(+), 102 deletions(-) diff --git a/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py index 78130412b7..4073ccc4d1 100644 --- a/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py +++ b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py @@ -1,23 +1,12 @@ from typing import Dict, Optional import torch -import torch.nn as nn - from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample -from vllm.logger import init_logger - - -logger = init_logger(__name__) class AscendTopKTopPSampler(TopKTopPSampler): - def __init__(self): - super().__init__() - # TODO(linfeng): eliminate warning for FlashInfer here - self.forward = self.forward_npu - - def forward_npu( + def forward_native( self, logits: torch.Tensor, generators: Dict[int, torch.Generator], @@ -28,37 +17,48 @@ def forward_npu( logits = apply_top_k_top_p_npu(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) - + def apply_top_k_top_p_npu( logits: torch.Tensor, k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: - """Apply top-k and top-p optimized for NPU. - - This algorithm avoids using torch.scatter which is time-consuming on NPU. - """ - # TODO(linfeng): consider the case taht either p or k is applied + """Apply top-k and/or top-p optimized for NPU.""" if k is None and p is None: return logits + batch_size, vocab_size = logits.shape + device = logits.device logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + if k is not None: + safe_k = torch.clamp(k, min=1, max=vocab_size) + boundary_idx = (vocab_size - safe_k).unsqueeze(1) + boundary = logits_sort.gather(1, boundary_idx) + top_k_mask = logits_sort < boundary + logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf")) + else: + top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool) - boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) - top_k_mask = logits_sort < boundary - logits_sort.masked_fill_(top_k_mask, -float("inf")) - cutoff = top_k_mask.sum(dim=-1).min() - probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = True - strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device) - flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) - valid_idx = torch.masked_select(flatten_idx, top_p_mask) + cutoffs = top_k_mask.sum(dim=-1) + strides = torch.arange(0, + batch_size * vocab_size, + vocab_size, + device=device).unsqueeze(1) + if p is not None: + global_cutoff = cutoffs.min() + active_part = logits_idx[:, global_cutoff:] + probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1) + cumprob = probs_sort.cumsum(dim=-1) + top_p_mask = (cumprob <= (1 - p.unsqueeze(1))) | (torch.arange( + probs_sort.size(1), device=device) == probs_sort.size(1) - 1) + else: + active_part = logits_idx + top_p_mask = torch.arange(vocab_size, device=device).expand( + batch_size, -1) >= 
cutoffs.unsqueeze(1) + valid_idx = (active_part + strides).masked_select(top_p_mask) logits_flatten = logits.flatten() - valid_logits = torch.index_select(logits_flatten, 0, valid_idx) - logits = torch.empty_like(logits_flatten).fill_(-float("inf")) - logits[valid_idx] = valid_logits - return logits.reshape(batch_size, vocab_size) \ No newline at end of file + output = torch.full_like(logits_flatten, -float('inf')) + output[valid_idx] = logits_flatten[valid_idx] + return output.reshape(batch_size, vocab_size) diff --git a/vllm_ascend/sample/ops/penalties.py b/vllm_ascend/sample/ops/penalties.py index c84aee7dfb..9fedc8fdfc 100644 --- a/vllm_ascend/sample/ops/penalties.py +++ b/vllm_ascend/sample/ops/penalties.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch - -from vllm.v1.sample.ops.penalties import _convert_to_tensors from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask +from vllm.v1.sample.ops.penalties import _convert_to_tensors def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, @@ -31,16 +30,17 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( 1, vocab_size) - + # Avoid IndexPut operations in original apply_penalties function which are extremely time-consuming on NPU. sequence_mask = prompt_mask | output_mask - logits = torch.where(sequence_mask & torch.lt(logits, 0), logits * repetition_penalties, - logits).to(logits.dtype) - logits = torch.where(sequence_mask & torch.ge(logits, 0), logits / repetition_penalties, - logits).to(logits.dtype) + logits = torch.where(sequence_mask & torch.lt(logits, 0), + logits * repetition_penalties, + logits).to(logits.dtype) + logits = torch.where(sequence_mask & torch.ge(logits, 0), + logits / repetition_penalties, + logits).to(logits.dtype) # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details @@ -48,6 +48,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits + def apply_all_penalties( logits: torch.Tensor, prompt_token_ids: torch.Tensor, @@ -64,4 +65,4 @@ def apply_all_penalties( logits.device) return apply_penalties(logits, prompt_token_ids, output_tokens_t, presence_penalties, frequency_penalties, - repetition_penalties) \ No newline at end of file + repetition_penalties) diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 18248afa0f..791839410a 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -3,15 +3,13 @@ from typing import Optional import torch -from vllm.model_executor.layers.sampler import (Sampler, - SamplerOutput, - _apply_min_tokens_penalty, - _apply_min_p, - _sample, - SampleResultArgsType, - get_logprobs, - _build_sampler_output) +from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType, + SamplerOutput, _apply_min_p, + _apply_min_tokens_penalty, + _build_sampler_output, _sample, + get_logprobs) from vllm.model_executor.sampling_metadata import SamplingMetadata + from vllm_ascend.sample.ops.penalties import apply_penalties @@ -61,7 +59,7 @@ def forward( if do_top_p_top_k: logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + sampling_tensors.top_ks) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) @@ -83,21 +81,15 @@ def forward( ) if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs assert maybe_sampled_tokens_tensor is not None on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. on_device_tensors = None # Get the logprobs query results. prompt_logprobs = None sample_logprobs = None if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType) prompt_logprobs, sample_logprobs = get_logprobs( @@ -121,10 +113,9 @@ def _apply_top_k_top_p_npu( This algorithm avoids using torch.scatter which is time-consuming on NPU. 
""" - # TODO(linfeng): consider the case taht either p or k is applied batch_size, vocab_size = logits.shape logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) top_k_mask = logits_sort < boundary logits_sort.masked_fill_(top_k_mask, -float("inf")) @@ -133,7 +124,10 @@ def _apply_top_k_top_p_npu( probs_sum = probs_sort.cumsum(dim=-1) top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) top_p_mask[:, -1] = True - strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device) + strides = torch.arange(0, + batch_size * vocab_size, + vocab_size, + device=logits.device) flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) valid_idx = torch.masked_select(flatten_idx, top_p_mask) logits_flatten = logits.flatten() diff --git a/vllm_ascend/sample/sampler_v1.py b/vllm_ascend/sample/sampler_v1.py index 0cb8006dcd..90195b5dc0 100644 --- a/vllm_ascend/sample/sampler_v1.py +++ b/vllm_ascend/sample/sampler_v1.py @@ -1,13 +1,11 @@ import torch -from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.penalties import apply_min_token_penalties -from vllm.logger import init_logger -from vllm_ascend.sample.ops.ascend_topk_topp_sampler import AscendTopKTopPSampler -from vllm_ascend.sample.ops.penalties import apply_all_penalties - +from vllm.v1.sample.sampler import Sampler -logger = init_logger(__name__) +from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \ + AscendTopKTopPSampler +from vllm_ascend.sample.ops.penalties import apply_all_penalties class AscendSampler(Sampler): @@ -35,4 +33,4 @@ def apply_penalties( sampling_metadata.repetition_penalties, sampling_metadata.output_token_ids, ) - return logits \ No newline at end of file + return logits diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 139fb7e3a9..ecddb2155c 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -60,6 +60,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) + from vllm_ascend.sample.sampler import AscendSampler if TYPE_CHECKING: @@ -821,12 +822,7 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: self.model = get_model(vllm_config=self.vllm_config) - # Same options with those in model_runner_v1.py - # option 1 - if hasattr(self.model, "sampler"): - self.model.sampler = AscendSampler() - # option 2 - # self.model = NPUModelWrapperV1(model) + self.model.sampler = AscendSampler() self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6a97c80c8c..f16afaa817 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -33,9 +33,7 @@ from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.sampler import sampler_output from vllm.model_executor.model_loader import get_model -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.platforms import current_platform from vllm.sampling_params import SamplingType @@ -813,11 +811,7 @@ 
def load_model(self) -> None: with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) - # option 1 - if hasattr(self.model, "sampler"): - self.model.sampler = AscendSampler() - # option 2 - # self.model = NPUModelWrapperV1(model) + self.model.sampler = AscendSampler() if self.lora_config: raise ValueError("LoRA model is not supported on NPU now.") @@ -898,25 +892,3 @@ def get_kv_cache_spec(self) -> KVCacheSpec: f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec - -# class NPUModelWrapperV1(nn.Module): - -# def __init__(self, model: nn.Module): -# super().__init__() -# self._model = model -# self.sampler = AscendSampler() - -# def __getattr__(self, name): -# return getattr(self._model, name) - -# def sample( -# self, -# logits: Optional[torch.Tensor], -# sampling_metadata: SamplingMetadata, -# ) -> Optional[SamplerOutput]: -# next_tokens = self.sampler(logits, sampling_metadata) -# return next_tokens - -# def forward(): -# # necessary if using wrapper class -# pass
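
The optimization both patches build on is replacing scatter/index_put with an ascending sort, a gathered boundary value, and a flat masked_select followed by plain advanced indexing. Below is a minimal, self-contained sketch of that scatter-free top-k masking; it is illustrative only (not part of the patch), runs on CPU, and every name in it is local to the example. The torch.topk-based reference exists purely as a cross-check.

import torch


def topk_mask_scatter_free(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Keep each row's top-k logits and set the rest to -inf, without scatter."""
    batch_size, vocab_size = logits.shape
    # Ascending sort: for row i, the k[i]-th largest value sits at column
    # vocab_size - k[i] of the sorted row (mirrors apply_top_k_top_p_npu).
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(1))
    keep_sorted = logits_sort >= boundary
    # Translate kept sorted positions back into flat indices of the original
    # [batch, vocab] layout, then rebuild the output with plain indexing
    # instead of torch.scatter / index_put.
    strides = torch.arange(0, batch_size * vocab_size, vocab_size,
                           device=logits.device).unsqueeze(1)
    valid_idx = (logits_idx + strides).masked_select(keep_sorted)
    out = torch.full_like(logits.flatten(), -float("inf"))
    out[valid_idx] = logits.flatten()[valid_idx]
    return out.reshape(batch_size, vocab_size)


def topk_mask_reference(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Row-by-row reference built on torch.topk, used only as a cross-check."""
    out = torch.full_like(logits, -float("inf"))
    for row in range(logits.shape[0]):
        values, indices = torch.topk(logits[row], int(k[row]))
        out[row, indices] = values
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 32)
    k = torch.tensor([1, 3, 8, 32])
    assert torch.equal(topk_mask_scatter_free(logits, k),
                       topk_mask_reference(logits, k))
    print("scatter-free top-k matches the torch.topk reference")

The penalties change in the patch follows the same idea: instead of index_put, a seen-token mask gates two torch.where calls that divide positive logits and multiply negative ones by the repetition penalty.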