Empty file.
64 changes: 64 additions & 0 deletions vllm_ascend/sample/ops/ascend_topk_topp_sampler.py
@@ -0,0 +1,64 @@
from typing import Dict, Optional

import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample


class AscendTopKTopPSampler(TopKTopPSampler):

Collaborator: I think this is duplicated with #970

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Optimized implementation of top-k and top-p sampling on NPU."""
        logits = apply_top_k_top_p_npu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)


def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and/or top-p optimized for NPU.

    Instead of scattering -inf values back through the sorted order (slow on
    NPU), kept positions are addressed through flattened indices and written
    with a single advanced-indexing assignment.
    """
    if k is None and p is None:
        return logits

    batch_size, vocab_size = logits.shape
    device = logits.device
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if k is not None:
        safe_k = torch.clamp(k, min=1, max=vocab_size)
        boundary_idx = (vocab_size - safe_k).unsqueeze(1)
        boundary = logits_sort.gather(1, boundary_idx)
        top_k_mask = logits_sort < boundary
        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
    else:
        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

    cutoffs = top_k_mask.sum(dim=-1)
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=device).unsqueeze(1)
    if p is not None:
        global_cutoff = cutoffs.min()
        active_part = logits_idx[:, global_cutoff:]
        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
        cumprob = probs_sort.cumsum(dim=-1)
        # Keep the top-p nucleus: positions whose cumulative probability
        # exceeds 1 - p in the ascending sort, always keeping the largest
        # logit so that at least one token survives.
        top_p_mask = (cumprob > (1 - p.unsqueeze(1))) | (torch.arange(
            probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
    else:
        active_part = logits_idx
        top_p_mask = torch.arange(vocab_size, device=device).expand(
            batch_size, -1) >= cutoffs.unsqueeze(1)

    valid_idx = (active_part + strides).masked_select(top_p_mask)
    logits_flatten = logits.flatten()
    output = torch.full_like(logits_flatten, -float("inf"))
    output[valid_idx] = logits_flatten[valid_idx]
    return output.reshape(batch_size, vocab_size)
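
For orientation, a minimal usage sketch of the helper above on plain CPU tensors (illustration only; the batch size, vocabulary size, and k/p values are invented):

import torch

logits = torch.randn(2, 8)              # 2 sequences, 8-token vocabulary
k = torch.tensor([3, 5])                # keep the 3 / 5 largest logits per row
p = torch.tensor([0.9, 0.8])            # nucleus mass per row

filtered = apply_top_k_top_p_npu(logits, k, p)
# Filtered-out positions are -inf, so softmax assigns them zero probability
# before tokens are drawn.
probs = filtered.softmax(dim=-1, dtype=torch.float32)
next_tokens = torch.multinomial(probs, num_samples=1)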
68 changes: 68 additions & 0 deletions vllm_ascend/sample/ops/penalties.py
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
from vllm.v1.sample.ops.penalties import _convert_to_tensors


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Optimized implementation of repetition penalties on NPU.

    Applies the penalties to the logits tensor and returns the result.

    logits: The input logits tensor of shape [num_seqs, vocab_size].
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, ).
    frequency_penalties: The frequency penalties of shape (num_seqs, ).
    repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)

    # Avoid the IndexPut operations used by the original apply_penalties
    # function, which are extremely time-consuming on NPU.
    sequence_mask = prompt_mask | output_mask
    logits = torch.where(sequence_mask & torch.lt(logits, 0),
                         logits * repetition_penalties,
                         logits).to(logits.dtype)
    logits = torch.where(sequence_mask & torch.ge(logits, 0),
                         logits / repetition_penalties,
                         logits).to(logits.dtype)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
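
As a rough sketch of how the V1 entry point above could be called (all ids and penalty values here are made up; prompts are padded with `vocab_size`, which never maps to a real token):

import torch

vocab_size = 16
logits = torch.randn(2, vocab_size)
prompt_token_ids = torch.tensor([[1, 4, 7, vocab_size],
                                 [2, 2, 9, 11]])
output_token_ids = [[4, 4], [9]]        # tokens generated so far per sequence

penalized = apply_all_penalties(
    logits,
    prompt_token_ids,
    presence_penalties=torch.tensor([0.5, 0.0]),
    frequency_penalties=torch.tensor([0.2, 0.0]),
    repetition_penalties=torch.tensor([1.2, 1.0]),
    output_token_ids=output_token_ids,
)
# Rows with presence/frequency 0.0 and repetition 1.0 are left unchanged
# by the corresponding penalty terms.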
137 changes: 137 additions & 0 deletions vllm_ascend/sample/sampler.py
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Optional

import torch
from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
                                                SamplerOutput, _apply_min_p,
                                                _apply_min_tokens_penalty,
                                                _build_sampler_output, _sample,
                                                get_logprobs)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from vllm_ascend.sample.ops.penalties import apply_penalties


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
                                            sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)


def _apply_top_k_top_p_npu(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p optimized for NPU.

    This algorithm avoids using torch.scatter which is time-consuming on NPU.
    """
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)
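
The scatter-free write-back used by both samplers can be seen in isolation below; the tensors are toy values chosen so the kept entries are easy to verify by hand:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 0.5],
                       [1.5, 0.0, 0.2, -0.3]])
batch_size, vocab_size = logits.shape
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

keep = torch.zeros_like(logits_sort, dtype=torch.bool)
keep[:, -2:] = True                     # pretend the top-2 per row survive

# Flattened addresses: row * vocab_size + original column of each kept entry.
strides = torch.arange(0, batch_size * vocab_size, vocab_size).unsqueeze(1)
valid_idx = (logits_idx + strides).masked_select(keep)

out = torch.full((batch_size * vocab_size, ), -float("inf"))
out[valid_idx] = logits.flatten()[valid_idx]
out = out.reshape(batch_size, vocab_size)
# Row 0 keeps 2.0 and 0.5, row 1 keeps 1.5 and 0.2; everything else is -inf.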
36 changes: 36 additions & 0 deletions vllm_ascend/sample/sampler_v1.py
@@ -0,0 +1,36 @@
import torch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import apply_min_token_penalties
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
    AscendTopKTopPSampler
from vllm_ascend.sample.ops.penalties import apply_all_penalties


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = AscendTopKTopPSampler()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits
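
For the control flow only, a stand-in metadata object (not the real `vllm.v1.sample.metadata.SamplingMetadata`, just the fields touched here) shows that the override returns the logits untouched when no penalties are configured:

import torch
from types import SimpleNamespace

from vllm_ascend.sample.sampler_v1 import AscendSampler

# Hypothetical stand-in: only the attributes read by apply_penalties.
meta = SimpleNamespace(
    min_tokens={},              # no min-token constraints
    no_penalties=True,          # presence/frequency/repetition all disabled
    prompt_token_ids=None,
    output_token_ids=[[]],
    presence_penalties=None,
    frequency_penalties=None,
    repetition_penalties=None,
)

sampler = AscendSampler()
logits = torch.randn(1, 16)
assert torch.equal(sampler.apply_penalties(logits, meta), logits)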
8 changes: 4 additions & 4 deletions vllm_ascend/worker/model_runner.py
@@ -43,8 +43,7 @@
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
-                                                get_sampler)
+from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
@@ -66,6 +65,7 @@
_init_sampling_metadata_from_tensor_dict)

from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.sample.sampler import AscendSampler

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -986,7 +986,7 @@ def __init__(
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
-self.sampler = get_sampler()
+self.sampler = AscendSampler()

def get_model(self) -> nn.Module:
return self.model
@@ -1487,7 +1487,7 @@ def execute_model(
model_input.async_callback()

# Sample the next token.
-assert isinstance(self.sampler, Sampler)
+assert isinstance(self.sampler, AscendSampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
4 changes: 2 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -53,7 +53,6 @@
KVCacheSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -68,6 +67,7 @@
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.sample.sampler_v1 import AscendSampler
from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

@@ -320,7 +320,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
self.attn_mask_len, self.dtype)

-self.sampler = Sampler()
+self.sampler = AscendSampler()

self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore