Merged
Empty file added vllm_ascend/sample/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions vllm_ascend/sample/ops/ascend_topk_topp_sampler.py
@@ -0,0 +1,64 @@
from typing import Dict, Optional

import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample


class AscendTopKTopPSampler(TopKTopPSampler):

def forward_native(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Optimized implementation of top-k and top-p sampling on NPU."""
logits = apply_top_k_top_p_npu(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)


def apply_top_k_top_p_npu(
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Apply top-k and/or top-p optimized for NPU."""
if k is None and p is None:
return logits

batch_size, vocab_size = logits.shape
device = logits.device
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
safe_k = torch.clamp(k, min=1, max=vocab_size)
boundary_idx = (vocab_size - safe_k).unsqueeze(1)
boundary = logits_sort.gather(1, boundary_idx)
top_k_mask = logits_sort < boundary
logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
else:
top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)

cutoffs = top_k_mask.sum(dim=-1)
strides = torch.arange(0,
batch_size * vocab_size,
vocab_size,
device=device).unsqueeze(1)
if p is not None:
global_cutoff = cutoffs.min()
active_part = logits_idx[:, global_cutoff:]
probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
cumprob = probs_sort.cumsum(dim=-1)
        top_p_mask = (cumprob > (1 - p.unsqueeze(1))) | (torch.arange(
            probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
else:
active_part = logits_idx
top_p_mask = torch.arange(vocab_size, device=device).expand(
batch_size, -1) >= cutoffs.unsqueeze(1)

valid_idx = (active_part + strides).masked_select(top_p_mask)
logits_flatten = logits.flatten()
output = torch.full_like(logits_flatten, -float('inf'))
output[valid_idx] = logits_flatten[valid_idx]
return output.reshape(batch_size, vocab_size)
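For reviewers who want to poke at the filtering in isolation, here is a minimal sketch (all values below are made up for illustration; the function itself is device-agnostic, the NPU win comes from building the result with masked_select and index assignment instead of torch.scatter):

import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import apply_top_k_top_p_npu

# Toy batch of 2 requests over a 4-token vocabulary (illustrative values only).
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                       [0.1, 0.2, 4.0, 3.5]])
k = torch.tensor([2, 3])       # per-request top-k
p = torch.tensor([0.9, 0.95])  # per-request top-p
filtered = apply_top_k_top_p_npu(logits, k, p)
# Tokens outside the combined top-k/top-p set are set to -inf, so the
# softmax in forward_native renormalizes over the surviving tokens.
print(filtered)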
68 changes: 68 additions & 0 deletions vllm_ascend/sample/ops/penalties.py
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
from vllm.v1.sample.ops.penalties import _convert_to_tensors


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor) -> torch.Tensor:
"""Optimized implementation of repetition penalties on NPU.

    Applies penalties in place to the logits tensor.
logits : The input logits tensor of shape [num_seqs, vocab_size]
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
are padded to the maximum prompt length within the batch using
`vocab_size` as the padding value. The value `vocab_size` is used
for padding because it does not correspond to any valid token ID
in the vocabulary.
output_tokens_tensor: The output tokens tensor.
presence_penalties: The presence penalties of shape (num_seqs, )
frequency_penalties: The frequency penalties of shape (num_seqs, )
repetition_penalties: The repetition penalties of shape (num_seqs, )
"""
num_seqs, vocab_size = logits.shape
_, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)

repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)

    # Avoid the IndexPut operations used in the original apply_penalties
    # function, which are extremely time-consuming on NPU.
sequence_mask = prompt_mask | output_mask
logits = torch.where(sequence_mask & torch.lt(logits, 0),
logits * repetition_penalties,
logits).to(logits.dtype)
logits = torch.where(sequence_mask & torch.ge(logits, 0),
logits / repetition_penalties,
logits).to(logits.dtype)

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits


def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
logits.device)
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
presence_penalties, frequency_penalties,
repetition_penalties)
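A self-contained sketch of driving apply_all_penalties with toy data, padded the way the docstring above describes (all tensor values below are illustrative):

import torch

from vllm_ascend.sample.ops.penalties import apply_all_penalties

vocab_size = 8
logits = torch.randn(2, vocab_size)
# Prompts padded to the max prompt length with vocab_size, the out-of-vocabulary pad id.
prompt_token_ids = torch.tensor([[1, 2, vocab_size],
                                 [3, vocab_size, vocab_size]])
presence_penalties = torch.tensor([0.5, 0.0])
frequency_penalties = torch.tensor([0.2, 0.0])
repetition_penalties = torch.tensor([1.2, 1.0])
output_token_ids = [[2, 2, 5], [4]]  # tokens generated so far, per request
penalized = apply_all_penalties(logits, prompt_token_ids, presence_penalties,
                                frequency_penalties, repetition_penalties,
                                output_token_ids)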
137 changes: 137 additions & 0 deletions vllm_ascend/sample/sampler.py
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Optional

import torch
from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
SamplerOutput, _apply_min_p,
_apply_min_tokens_penalty,
_build_sampler_output, _sample,
get_logprobs)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from vllm_ascend.sample.ops.penalties import apply_penalties


class AscendSampler(Sampler):

def __init__(self):
super().__init__()

def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
assert logits is not None
_, vocab_size = logits.shape

# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)

assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p

logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence, frequency and repetition penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)

# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)

if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
on_device_tensors = None

# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)

return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)


def _apply_top_k_top_p_npu(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""Apply top-k and top-p optimized for NPU.

This algorithm avoids using torch.scatter which is time-consuming on NPU.
"""
batch_size, vocab_size = logits.shape
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
top_k_mask = logits_sort < boundary
logits_sort.masked_fill_(top_k_mask, -float("inf"))
cutoff = top_k_mask.sum(dim=-1).min()
probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = True
strides = torch.arange(0,
batch_size * vocab_size,
vocab_size,
device=logits.device)
flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
valid_idx = torch.masked_select(flatten_idx, top_p_mask)
logits_flatten = logits.flatten()
valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
logits[valid_idx] = valid_logits
return logits.reshape(batch_size, vocab_size)
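A rough way to sanity-check the scatter-free path against the stock mask-and-scatter formulation; the _reference_top_k_top_p helper below is written only for this note (it mirrors vLLM's sorted-order masking) and is not part of the change:

import torch

from vllm_ascend.sample.sampler import _apply_top_k_top_p_npu

def _reference_top_k_top_p(logits, p, k):
    # Mask in sorted order, then scatter back (the pattern this PR avoids).
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    boundary = logits_sort.gather(1, (logits.shape[1] - k).unsqueeze(1))
    logits_sort.masked_fill_(logits_sort < boundary, -float("inf"))
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(1)
    top_p_mask[:, -1] = False  # always keep the most likely token
    logits_sort.masked_fill_(top_p_mask, -float("inf"))
    return torch.empty_like(logits_sort).scatter_(1, logits_idx, logits_sort)

logits = torch.randn(4, 32)
p = torch.full((4,), 0.8)
k = torch.randint(1, 33, (4,))
expected = _reference_top_k_top_p(logits.clone(), p, k)
actual = _apply_top_k_top_p_npu(logits.clone(), p, k)
# Both paths should mask (-inf) the same token positions.
assert torch.equal(actual.isinf(), expected.isinf())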
36 changes: 36 additions & 0 deletions vllm_ascend/sample/sampler_v1.py
@@ -0,0 +1,36 @@
import torch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import apply_min_token_penalties
from vllm.v1.sample.sampler import Sampler

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
AscendTopKTopPSampler
from vllm_ascend.sample.ops.penalties import apply_all_penalties


class AscendSampler(Sampler):

def __init__(self):
super().__init__()
self.topk_topp_sampler = AscendTopKTopPSampler()

def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.min_tokens:
apply_min_token_penalties(logits,
sampling_metadata.output_token_ids,
sampling_metadata.min_tokens)
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids,
)
return logits
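Illustrative smoke test for the V1 override, assuming the vllm revision this PR targets; the SimpleNamespace stands in for SamplingMetadata and carries only the attributes apply_penalties reads:

import types

import torch

from vllm_ascend.sample.sampler_v1 import AscendSampler

sampler = AscendSampler()
logits = torch.randn(2, 16)
meta = types.SimpleNamespace(
    min_tokens={},             # no min-token constraints to enforce
    no_penalties=True,         # penalties disabled -> logits pass through
    prompt_token_ids=None,
    presence_penalties=None,
    frequency_penalties=None,
    repetition_penalties=None,
    output_token_ids=[[], []],
)
assert sampler.apply_penalties(logits, meta) is logits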
4 changes: 3 additions & 1 deletion vllm_ascend/worker/model_runner.py
@@ -61,6 +61,8 @@
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)

from vllm_ascend.sample.sampler import AscendSampler

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

@@ -820,7 +822,7 @@ def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(vllm_config=self.vllm_config)

self.model.sampler = AscendSampler()
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
3 changes: 3 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -52,6 +52,7 @@
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.sample.sampler_v1 import AscendSampler

if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
@@ -810,6 +811,8 @@ def load_model(self) -> None:

with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
self.model.sampler = AscendSampler()

if self.lora_config:
raise ValueError("LoRA model is not supported on NPU now.")
