diff --git a/vllm_ascend/sample/ops/__init__.py b/vllm_ascend/sample/ops/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py
new file mode 100644
index 0000000000..4073ccc4d1
--- /dev/null
+++ b/vllm_ascend/sample/ops/ascend_topk_topp_sampler.py
@@ -0,0 +1,64 @@
+from typing import Dict, Optional
+
+import torch
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+
+
+class AscendTopKTopPSampler(TopKTopPSampler):
+
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        generators: Dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        """Optimized implementation of top-k and top-p sampling on NPU."""
+        logits = apply_top_k_top_p_npu(logits, k, p)
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        return random_sample(probs, generators)
+
+
+def apply_top_k_top_p_npu(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and/or top-p optimized for NPU."""
+    if k is None and p is None:
+        return logits
+
+    batch_size, vocab_size = logits.shape
+    device = logits.device
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    if k is not None:
+        safe_k = torch.clamp(k, min=1, max=vocab_size)
+        boundary_idx = (vocab_size - safe_k).unsqueeze(1)
+        boundary = logits_sort.gather(1, boundary_idx)
+        top_k_mask = logits_sort < boundary
+        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
+    else:
+        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)
+
+    cutoffs = top_k_mask.sum(dim=-1)
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=device).unsqueeze(1)
+    if p is not None:
+        global_cutoff = cutoffs.min()
+        active_part = logits_idx[:, global_cutoff:]
+        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
+        cumprob = probs_sort.cumsum(dim=-1)
+        top_p_mask = (cumprob > (1 - p.unsqueeze(1))) | (torch.arange(
+            probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
+    else:
+        active_part = logits_idx
+        top_p_mask = torch.arange(vocab_size, device=device).expand(
+            batch_size, -1) >= cutoffs.unsqueeze(1)
+
+    valid_idx = (active_part + strides).masked_select(top_p_mask)
+    logits_flatten = logits.flatten()
+    output = torch.full_like(logits_flatten, -float('inf'))
+    output[valid_idx] = logits_flatten[valid_idx]
+    return output.reshape(batch_size, vocab_size)
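Since the logits are sorted in ascending order here, the top-p nucleus is the suffix whose cumulative probability exceeds `1 - p`, so the comparison must be `cumprob > (1 - p)` (with the final, highest-probability token always kept). A quick way to see what `apply_top_k_top_p_npu` computes is to run it on a tiny batch; the function uses only ordinary torch kernels, so this sketch (with made-up values) also runs on CPU:

```python
import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
    apply_top_k_top_p_npu

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
k = torch.tensor([3])    # keep the 3 largest logits
p = torch.tensor([0.9])  # then keep the smallest suffix with mass >= 0.9

out = apply_top_k_top_p_npu(logits, k, p)
# Top-k keeps tokens 1, 2, 3 (logits 2, 3, 4). Within that set,
# softmax([2, 3, 4]) ~= [0.09, 0.24, 0.67], so token 1's cumulative
# probability (0.09) never exceeds 1 - p = 0.1 and top-p drops it:
print(out)  # tensor([[-inf, -inf, 3., 4.]])
```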
diff --git a/vllm_ascend/sample/ops/penalties.py b/vllm_ascend/sample/ops/penalties.py
new file mode 100644
index 0000000000..9fedc8fdfc
--- /dev/null
+++ b/vllm_ascend/sample/ops/penalties.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
+from vllm.v1.sample.ops.penalties import _convert_to_tensors
+
+
+def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                    output_tokens_tensor: torch.Tensor,
+                    presence_penalties: torch.Tensor,
+                    frequency_penalties: torch.Tensor,
+                    repetition_penalties: torch.Tensor) -> torch.Tensor:
+    """Optimized implementation of repetition penalties on NPU.
+
+    Applies penalties in place to the logits tensor.
+    logits : The input logits tensor of shape [num_seqs, vocab_size]
+    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
+        are padded to the maximum prompt length within the batch using
+        `vocab_size` as the padding value. The value `vocab_size` is used
+        for padding because it does not correspond to any valid token ID
+        in the vocabulary.
+    output_tokens_tensor: The output tokens tensor.
+    presence_penalties: The presence penalties of shape (num_seqs, )
+    frequency_penalties: The frequency penalties of shape (num_seqs, )
+    repetition_penalties: The repetition penalties of shape (num_seqs, )
+    """
+    num_seqs, vocab_size = logits.shape
+    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
+                                                   vocab_size, num_seqs)
+    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
+        output_tokens_tensor, vocab_size, num_seqs)
+
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
+        1, vocab_size)
+
+    # Avoid the IndexPut operations of the original apply_penalties function, which are extremely time-consuming on NPU.
+    sequence_mask = prompt_mask | output_mask
+    logits = torch.where(sequence_mask & torch.lt(logits, 0),
+                         logits * repetition_penalties,
+                         logits).to(logits.dtype)
+    logits = torch.where(sequence_mask & torch.ge(logits, 0),
+                         logits / repetition_penalties,
+                         logits).to(logits.dtype)
+
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
+    return logits
+
+
+def apply_all_penalties(
+    logits: torch.Tensor,
+    prompt_token_ids: torch.Tensor,
+    presence_penalties: torch.Tensor,
+    frequency_penalties: torch.Tensor,
+    repetition_penalties: torch.Tensor,
+    output_token_ids: list[list[int]],
+) -> torch.Tensor:
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
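The two `torch.where` calls above replace the boolean-mask assignments (`logits[mask] /= penalty`) used in vLLM's reference `apply_penalties`; mask assignment lowers to IndexPut, which is the slow path on NPU. A minimal CPU sketch with made-up values, checking that the branch-free form matches the masked-indexing reference it replaces:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
seen = torch.tensor([[True, True, False, False]])  # prompt_mask | output_mask
penalty = torch.full_like(logits, 2.0)  # repetition penalty, broadcast shape

# IndexPut-style reference: divide positive seen logits, multiply negative.
ref = logits.clone()
ref[seen & (ref >= 0)] /= 2.0
ref[seen & (ref < 0)] *= 2.0

# Branch-free version used above: negatives are scaled first and stay
# negative, so the second where only touches the non-negative entries.
out = torch.where(seen & (logits < 0), logits * penalty, logits)
out = torch.where(seen & (out >= 0), out / penalty, out)

assert torch.equal(ref, out)  # tensor([[1., -2., 0.5, -3.]])
```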
diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py
new file mode 100644
index 0000000000..791839410a
--- /dev/null
+++ b/vllm_ascend/sample/sampler.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A layer that samples the next tokens from the model's outputs."""
+from typing import Optional
+
+import torch
+from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
+                                                SamplerOutput, _apply_min_p,
+                                                _apply_min_tokens_penalty,
+                                                _build_sampler_output, _sample,
+                                                get_logprobs)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+from vllm_ascend.sample.ops.penalties import apply_penalties
+
+
+class AscendSampler(Sampler):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        assert logits is not None
+        _, vocab_size = logits.shape
+
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        if not sampling_metadata.reuse_sampling_tensors:
+            self._init_sampling_tensors(logits, sampling_metadata)
+        elif self._do_penalties:
+            # In this case, the sampling tensors logic depends on
+            # "output_tokens" of a sequence. As a result, we cannot
+            # reuse sampling tensors, since "output_tokens" changes
+            # between decode runs.
+            self._init_sampling_tensors(logits, sampling_metadata)
+
+        assert self._sampling_tensors is not None
+        sampling_tensors = self._sampling_tensors
+        do_penalties = self._do_penalties
+        do_top_p_top_k = self._do_top_p_top_k
+        do_min_p = self._do_min_p
+
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
+        # Apply presence and frequency penalties.
+        if do_penalties:
+            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                     sampling_tensors.output_tokens,
+                                     sampling_tensors.presence_penalties,
+                                     sampling_tensors.frequency_penalties,
+                                     sampling_tensors.repetition_penalties)
+
+        # Use float32 to apply temperature scaling.
+        # Use in-place division to avoid creating a new tensor.
+        logits = logits.to(torch.float)
+        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+
+        if do_top_p_top_k:
+            logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
+
+        if do_min_p:
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+        # Sample the next tokens.
+        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
+            probs,
+            logprobs,
+            sampling_metadata,
+            sampling_tensors,
+            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
+            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
+        )
+
+        if self.include_gpu_probs_tensor:
+            assert maybe_sampled_tokens_tensor is not None
+            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
+        else:
+            on_device_tensors = None
+
+        # Get the logprobs query results.
+        prompt_logprobs = None
+        sample_logprobs = None
+        if not sampling_metadata.skip_sampler_cpu_output:
+            assert not isinstance(maybe_deferred_sample_results,
+                                  SampleResultArgsType)
+            prompt_logprobs, sample_logprobs = get_logprobs(
+                logprobs, sampling_metadata, maybe_deferred_sample_results)
+
+        return _build_sampler_output(
+            maybe_deferred_sample_results,
+            sampling_metadata,
+            prompt_logprobs,
+            sample_logprobs,
+            on_device_tensors=on_device_tensors,
+            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
+
+
+def _apply_top_k_top_p_npu(
+    logits: torch.Tensor,
+    p: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-k and top-p optimized for NPU.
+
+    This algorithm avoids using torch.scatter, which is time-consuming on NPU.
+    """
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+    cutoff = top_k_mask.sum(dim=-1).min()
+    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+    top_p_mask[:, -1] = True
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=logits.device)
+    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+    logits_flatten = logits.flatten()
+    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+    logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
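The row-stride trick at the end of `_apply_top_k_top_p_npu` is what makes the scatter-free restore work: adding `row * vocab_size` to each row's column indices yields global positions in the flattened tensor, so every surviving logit is written back with one index assignment. A standalone sketch of just that step, with made-up indices:

```python
import torch

batch_size, vocab_size = 2, 5
logits = torch.arange(10, dtype=torch.float32).reshape(batch_size, vocab_size)

cols = torch.tensor([[3, 4], [0, 1]])  # per-row columns of interest
keep = torch.tensor([[True, True], [False, True]])

strides = torch.arange(0, batch_size * vocab_size,
                       vocab_size).unsqueeze(dim=1)   # [[0], [5]]
flat_idx = torch.masked_select(cols + strides, keep)  # tensor([3, 4, 6])

flat = logits.flatten()
out = torch.full_like(flat, -float("inf"))
out[flat_idx] = flat[flat_idx]  # single advanced-index write, no scatter
print(out.reshape(batch_size, vocab_size))
# tensor([[-inf, -inf, -inf, 3., 4.],
#         [-inf, 6., -inf, -inf, -inf]])
```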
+ """ + batch_size, vocab_size = logits.shape + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) + top_k_mask = logits_sort < boundary + logits_sort.masked_fill_(top_k_mask, -float("inf")) + cutoff = top_k_mask.sum(dim=-1).min() + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = True + strides = torch.arange(0, + batch_size * vocab_size, + vocab_size, + device=logits.device) + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) + valid_idx = torch.masked_select(flatten_idx, top_p_mask) + logits_flatten = logits.flatten() + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) + logits[valid_idx] = valid_logits + return logits.reshape(batch_size, vocab_size) diff --git a/vllm_ascend/sample/sampler_v1.py b/vllm_ascend/sample/sampler_v1.py new file mode 100644 index 0000000000..90195b5dc0 --- /dev/null +++ b/vllm_ascend/sample/sampler_v1.py @@ -0,0 +1,36 @@ +import torch +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.penalties import apply_min_token_penalties +from vllm.v1.sample.sampler import Sampler + +from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \ + AscendTopKTopPSampler +from vllm_ascend.sample.ops.penalties import apply_all_penalties + + +class AscendSampler(Sampler): + + def __init__(self): + super().__init__() + self.topk_topp_sampler = AscendTopKTopPSampler() + + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + logits = apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids, + ) + return logits diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 43059b82cc..b5dbeceefd 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -43,8 +43,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, - get_sampler) +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -66,6 +65,7 @@ _init_sampling_metadata_from_tensor_dict) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.sample.sampler import AscendSampler if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -986,7 +986,7 @@ def __init__( self.sampling_metadata_cache: SamplingMetadataCache = \ SamplingMetadataCache() \ if self.parallel_config.pipeline_parallel_size == 1 else None - self.sampler = get_sampler() + self.sampler = AscendSampler() def get_model(self) -> nn.Module: return 
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index 43059b82cc..b5dbeceefd 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -43,8 +43,7 @@
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
-                                                get_sampler)
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models import supports_lora, supports_multimodal
@@ -66,6 +65,7 @@
     _init_sampling_metadata_from_tensor_dict)
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.sample.sampler import AscendSampler
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -986,7 +986,7 @@ def __init__(
         self.sampling_metadata_cache: SamplingMetadataCache = \
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
-        self.sampler = get_sampler()
+        self.sampler = AscendSampler()
 
     def get_model(self) -> nn.Module:
         return self.model
@@ -1487,7 +1487,7 @@ def execute_model(
                 model_input.async_callback()
 
             # Sample the next token.
-            assert isinstance(self.sampler, Sampler)
+            assert isinstance(self.sampler, AscendSampler)
             orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
             if model_input.inputs_embeds is not None:
                 self.sampler.include_gpu_probs_tensor = True
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 12a596c109..3b8de29af0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -53,7 +53,6 @@
     KVCacheSpec)
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -68,6 +67,7 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.sample.sampler_v1 import AscendSampler
 from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -320,7 +320,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
             self.attn_mask_len, self.dtype)
 
-        self.sampler = Sampler()
+        self.sampler = AscendSampler()
         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
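The v0 helper in `sampler.py` and the v1 op in `ascend_topk_topp_sampler.py` implement the same filtering with slightly different slicing (softmax-then-slice vs. slice-then-softmax), so a throwaway script (not part of this diff) can cross-check that they keep the same token set:

```python
import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
    apply_top_k_top_p_npu
from vllm_ascend.sample.sampler import _apply_top_k_top_p_npu

torch.manual_seed(0)
logits = torch.randn(8, 128)
k = torch.randint(1, 64, (8,))
p = torch.rand(8) * 0.5 + 0.5  # top-p in [0.5, 1.0)

v1_out = apply_top_k_top_p_npu(logits.clone(), k, p)
v0_out = _apply_top_k_top_p_npu(logits.clone(), p, k)  # note the (p, k) order
print(torch.equal(v1_out.isinf(), v0_out.isinf()))  # expected: True
```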