|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""A layer that samples the next tokens from the model's outputs.""" |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType, |
| 7 | + SamplerOutput, _apply_min_p, |
| 8 | + _apply_min_tokens_penalty, |
| 9 | + _build_sampler_output, _sample, |
| 10 | + get_logprobs) |
| 11 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 12 | + |
| 13 | +from vllm_ascend.sample.ops.penalties import apply_penalties |
| 14 | + |
| 15 | + |
| 16 | +class AscendSampler(Sampler): |
| 17 | + |
| 18 | + def __init__(self): |
| 19 | + super().__init__() |
| 20 | + |
| 21 | + def forward( |
| 22 | + self, |
| 23 | + logits: torch.Tensor, |
| 24 | + sampling_metadata: SamplingMetadata, |
| 25 | + ) -> Optional[SamplerOutput]: |
| 26 | + assert logits is not None |
| 27 | + _, vocab_size = logits.shape |
| 28 | + |
| 29 | + # Prepare sampling tensors with pinned memory to avoid blocking. |
| 30 | + if not sampling_metadata.reuse_sampling_tensors: |
| 31 | + self._init_sampling_tensors(logits, sampling_metadata) |
| 32 | + elif self._do_penalties: |
| 33 | + # In this case, the sampling tensors logic depends on |
| 34 | + # "output_tokens" of a sequence. As a result, we cannot |
| 35 | + # reuse sampling tensors, since "output_tokens" changes |
| 36 | + # between decode runs. |
| 37 | + self._init_sampling_tensors(logits, sampling_metadata) |
| 38 | + |
| 39 | + assert self._sampling_tensors is not None |
| 40 | + sampling_tensors = self._sampling_tensors |
| 41 | + do_penalties = self._do_penalties |
| 42 | + do_top_p_top_k = self._do_top_p_top_k |
| 43 | + do_min_p = self._do_min_p |
| 44 | + |
| 45 | + logits = _apply_min_tokens_penalty(logits, sampling_metadata) |
| 46 | + |
| 47 | + # Apply presence and frequency penalties. |
| 48 | + if do_penalties: |
| 49 | + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, |
| 50 | + sampling_tensors.output_tokens, |
| 51 | + sampling_tensors.presence_penalties, |
| 52 | + sampling_tensors.frequency_penalties, |
| 53 | + sampling_tensors.repetition_penalties) |
| 54 | + |
| 55 | + # Use float32 to apply temperature scaling. |
| 56 | + # Use in-place division to avoid creating a new tensor. |
| 57 | + logits = logits.to(torch.float) |
| 58 | + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) |
| 59 | + |
| 60 | + if do_top_p_top_k: |
| 61 | + logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps, |
| 62 | + sampling_tensors.top_ks) |
| 63 | + |
| 64 | + if do_min_p: |
| 65 | + logits = _apply_min_p(logits, sampling_tensors.min_ps) |
| 66 | + |
| 67 | + # We use float32 for probabilities and log probabilities. |
| 68 | + # Compute the probabilities. |
| 69 | + probs = torch.softmax(logits, dim=-1, dtype=torch.float) |
| 70 | + # Compute the log probabilities. |
| 71 | + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) |
| 72 | + |
| 73 | + # Sample the next tokens. |
| 74 | + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( |
| 75 | + probs, |
| 76 | + logprobs, |
| 77 | + sampling_metadata, |
| 78 | + sampling_tensors, |
| 79 | + include_gpu_probs_tensor=self.include_gpu_probs_tensor, |
| 80 | + modify_greedy_probs=self._should_modify_greedy_probs_inplace, |
| 81 | + ) |
| 82 | + |
| 83 | + if self.include_gpu_probs_tensor: |
| 84 | + assert maybe_sampled_tokens_tensor is not None |
| 85 | + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) |
| 86 | + else: |
| 87 | + on_device_tensors = None |
| 88 | + |
| 89 | + # Get the logprobs query results. |
| 90 | + prompt_logprobs = None |
| 91 | + sample_logprobs = None |
| 92 | + if not sampling_metadata.skip_sampler_cpu_output: |
| 93 | + assert not isinstance(maybe_deferred_sample_results, |
| 94 | + SampleResultArgsType) |
| 95 | + prompt_logprobs, sample_logprobs = get_logprobs( |
| 96 | + logprobs, sampling_metadata, maybe_deferred_sample_results) |
| 97 | + |
| 98 | + return _build_sampler_output( |
| 99 | + maybe_deferred_sample_results, |
| 100 | + sampling_metadata, |
| 101 | + prompt_logprobs, |
| 102 | + sample_logprobs, |
| 103 | + on_device_tensors=on_device_tensors, |
| 104 | + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) |
| 105 | + |
| 106 | + |
| 107 | +def _apply_top_k_top_p_npu( |
| 108 | + logits: torch.Tensor, |
| 109 | + p: torch.Tensor, |
| 110 | + k: torch.Tensor, |
| 111 | +) -> torch.Tensor: |
| 112 | + """Apply top-k and top-p optimized for NPU. |
| 113 | +
|
| 114 | + This algorithm avoids using torch.scatter which is time-consuming on NPU. |
| 115 | + """ |
| 116 | + batch_size, vocab_size = logits.shape |
| 117 | + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) |
| 118 | + |
| 119 | + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) |
| 120 | + top_k_mask = logits_sort < boundary |
| 121 | + logits_sort.masked_fill_(top_k_mask, -float("inf")) |
| 122 | + cutoff = top_k_mask.sum(dim=-1).min() |
| 123 | + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] |
| 124 | + probs_sum = probs_sort.cumsum(dim=-1) |
| 125 | + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) |
| 126 | + top_p_mask[:, -1] = True |
| 127 | + strides = torch.arange(0, |
| 128 | + batch_size * vocab_size, |
| 129 | + vocab_size, |
| 130 | + device=logits.device) |
| 131 | + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) |
| 132 | + valid_idx = torch.masked_select(flatten_idx, top_p_mask) |
| 133 | + logits_flatten = logits.flatten() |
| 134 | + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) |
| 135 | + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) |
| 136 | + logits[valid_idx] = valid_logits |
| 137 | + return logits.reshape(batch_size, vocab_size) |
0 commit comments