
Commit b55ffca

perf(npu): greatly accelerate post-processing on Ascend platform
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent 15314cc commit b55ffca

File tree

8 files changed: +350 -1 lines changed

vllm_ascend/sample/__init__.py

Whitespace-only changes.

vllm_ascend/sample/ops/__init__.py

Whitespace-only changes.

vllm_ascend/sample/ops/ascend_topk_topp_sampler.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
from typing import Dict, Optional

import torch
import torch.nn as nn

from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.logger import init_logger


logger = init_logger(__name__)


class AscendTopKTopPSampler(TopKTopPSampler):

    def __init__(self):
        super().__init__()
        # TODO(linfeng): eliminate warning for FlashInfer here
        self.forward = self.forward_npu

    def forward_npu(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        k: Optional[torch.Tensor],
        p: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Optimized implementation of top-k and top-p sampling on NPU."""
        logits = apply_top_k_top_p_npu(logits, k, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)


def apply_top_k_top_p_npu(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p optimized for NPU.

    This algorithm avoids using torch.scatter, which is time-consuming on NPU.
    """
    # TODO(linfeng): consider the case where only one of p or k is applied
    if k is None and p is None:
        return logits
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)

    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)
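
For reference, a minimal usage sketch of the masking helper above. It assumes the file lands at vllm_ascend/sample/ops/ascend_topk_topp_sampler.py (the import path used by sampler_v1.py later in this commit) and that vLLM is installed; the toy CPU tensors are only for illustration, since on Ascend the tensors would live on the NPU device.

import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import apply_top_k_top_p_npu

# Two sequences over a toy vocabulary of 6 tokens.
logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0, 0.0],
                       [0.2, 0.1, 4.0, 1.0, 2.5, -2.0]])
k = torch.tensor([2, 3])      # per-sequence top-k
p = torch.tensor([0.9, 0.8])  # per-sequence top-p

masked = apply_top_k_top_p_npu(logits.clone(), k, p)
# Tokens outside the per-sequence top-k/top-p nucleus are set to -inf; the
# surviving tokens keep their original logit values, so a plain softmax and
# multinomial sample can follow.
print(masked)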

vllm_ascend/sample/ops/penalties.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.v1.sample.ops.penalties import _convert_to_tensors
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Optimized implementation of repetition penalties on NPU.

    Applies penalties in place to the logits tensor.

    logits: The input logits tensor of shape [num_seqs, vocab_size].
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, ).
    frequency_penalties: The frequency penalties of shape (num_seqs, ).
    repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, vocab_size)

    # Avoid the IndexPut operations in the original apply_penalties function,
    # which are extremely time-consuming on NPU.
    sequence_mask = prompt_mask | output_mask
    logits = torch.where(sequence_mask & torch.lt(logits, 0),
                         logits * repetition_penalties,
                         logits).to(logits.dtype)
    logits = torch.where(sequence_mask & torch.ge(logits, 0),
                         logits / repetition_penalties,
                         logits).to(logits.dtype)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
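
A minimal sketch of calling the penalty path above, assuming vLLM is installed and the module is importable as vllm_ascend.sample.ops.penalties (the path used by sampler_v1.py below). Per the docstring, prompts are padded with `vocab_size` as an out-of-vocabulary id; the toy CPU tensors here are only for illustration.

import torch

from vllm_ascend.sample.ops.penalties import apply_all_penalties

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size)

# Prompts padded to equal length with `vocab_size` (an out-of-vocabulary id).
prompt_token_ids = torch.tensor([[0, 1, 2, vocab_size],
                                 [3, 4, vocab_size, vocab_size]])
output_token_ids = [[2, 2, 5], [4]]  # tokens generated so far, per sequence

penalized = apply_all_penalties(
    logits,
    prompt_token_ids,
    presence_penalties=torch.tensor([0.5, 0.0]),
    frequency_penalties=torch.tensor([0.2, 0.0]),
    repetition_penalties=torch.tensor([1.2, 1.0]),
    output_token_ids=output_token_ids,
)
# Sequence 0 has its seen tokens penalized; sequence 1 is left unchanged.
print(penalized)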

vllm_ascend/sample/sampler.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Optional

import torch
from vllm.model_executor.layers.sampler import (Sampler,
                                                SamplerOutput,
                                                _apply_min_tokens_penalty,
                                                _apply_min_p,
                                                _sample,
                                                SampleResultArgsType,
                                                get_logprobs,
                                                _build_sampler_output)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm_ascend.sample.ops.penalties import apply_penalties


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
                                            sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs.
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)


def _apply_top_k_top_p_npu(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p optimized for NPU.

    This algorithm avoids using torch.scatter, which is time-consuming on NPU.
    """
    # TODO(linfeng): consider the case where only one of p or k is applied
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)

vllm_ascend/sample/sampler_v1.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import torch
from vllm.v1.sample.sampler import Sampler
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import apply_min_token_penalties
from vllm.logger import init_logger
from vllm_ascend.sample.ops.ascend_topk_topp_sampler import AscendTopKTopPSampler
from vllm_ascend.sample.ops.penalties import apply_all_penalties


logger = init_logger(__name__)


class AscendSampler(Sampler):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = AscendTopKTopPSampler()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

vllm_ascend/worker/model_runner.py

Lines changed: 7 additions & 1 deletion
@@ -60,6 +60,7 @@
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
+from vllm_ascend.sample.sampler import AscendSampler

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend

@@ -820,7 +821,12 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
-
+            # Same options as those in model_runner_v1.py
+            # option 1
+            if hasattr(self.model, "sampler"):
+                self.model.sampler = AscendSampler()
+            # option 2
+            # self.model = NPUModelWrapperV1(model)
         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))
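
The "option 1" wiring above amounts to swapping the model's sampler for the NPU-optimized one right after loading. A minimal standalone sketch of that idea follows; the helper name use_ascend_sampler is hypothetical and not part of this commit.

from torch import nn

from vllm_ascend.sample.sampler import AscendSampler


def use_ascend_sampler(model: nn.Module) -> nn.Module:
    """Swap in the NPU-optimized sampler when the model exposes one (option 1)."""
    if hasattr(model, "sampler"):
        model.sampler = AscendSampler()
    return model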

vllm_ascend/worker/model_runner_v1.py

Lines changed: 31 additions & 0 deletions
@@ -33,7 +33,9 @@
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingType

@@ -52,6 +54,7 @@
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
+from vllm_ascend.sample.sampler_v1 import AscendSampler

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput

@@ -810,6 +813,12 @@ def load_model(self) -> None:

         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            # option 1
+            if hasattr(self.model, "sampler"):
+                self.model.sampler = AscendSampler()
+            # option 2
+            # self.model = NPUModelWrapperV1(model)
+
         if self.lora_config:
             raise ValueError("LoRA model is not supported on NPU now.")

@@ -889,3 +898,25 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
                     f"Unknown attention type: {attn_module.attn_type}")

         return kv_cache_spec
+
+# class NPUModelWrapperV1(nn.Module):
+
+#     def __init__(self, model: nn.Module):
+#         super().__init__()
+#         self._model = model
+#         self.sampler = AscendSampler()
+
+#     def __getattr__(self, name):
+#         return getattr(self._model, name)
+
+#     def sample(
+#         self,
+#         logits: Optional[torch.Tensor],
+#         sampling_metadata: SamplingMetadata,
+#     ) -> Optional[SamplerOutput]:
+#         next_tokens = self.sampler(logits, sampling_metadata)
+#         return next_tokens
+
+#     def forward():
+#         # necessary if using wrapper class
+#         pass
