
Commit d377ba3

refactor: support scenarios where top_p or top_k is None
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent b55ffca commit d377ba3

File tree

6 files changed: +63 -102 lines changed

vllm_ascend/sample/ops/ascend_topk_topp_sampler.py
vllm_ascend/sample/ops/penalties.py
vllm_ascend/sample/sampler.py
vllm_ascend/sample/sampler_v1.py
vllm_ascend/worker/model_runner.py
vllm_ascend/worker/model_runner_v1.py
vllm_ascend/sample/ops/ascend_topk_topp_sampler.py

Lines changed: 33 additions & 33 deletions
@@ -1,23 +1,12 @@
 from typing import Dict, Optional
 
 import torch
-import torch.nn as nn
-
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
-from vllm.logger import init_logger
-
-
-logger = init_logger(__name__)
 
 
 class AscendTopKTopPSampler(TopKTopPSampler):
 
-    def __init__(self):
-        super().__init__()
-        # TODO(linfeng): eliminate warning for FlashInfer here
-        self.forward = self.forward_npu
-
-    def forward_npu(
+    def forward_native(
         self,
         logits: torch.Tensor,
         generators: Dict[int, torch.Generator],
@@ -28,37 +17,48 @@ def forward_npu(
         logits = apply_top_k_top_p_npu(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
-
+
 
 def apply_top_k_top_p_npu(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
 ) -> torch.Tensor:
-    """Apply top-k and top-p optimized for NPU.
-
-    This algorithm avoids using torch.scatter which is time-consuming on NPU.
-    """
-    # TODO(linfeng): consider the case taht either p or k is applied
+    """Apply top-k and/or top-p optimized for NPU."""
     if k is None and p is None:
         return logits
+
     batch_size, vocab_size = logits.shape
+    device = logits.device
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    if k is not None:
+        safe_k = torch.clamp(k, min=1, max=vocab_size)
+        boundary_idx = (vocab_size - safe_k).unsqueeze(1)
+        boundary = logits_sort.gather(1, boundary_idx)
+        top_k_mask = logits_sort < boundary
+        logits_sort = logits_sort.masked_fill(top_k_mask, -float("inf"))
+    else:
+        top_k_mask = torch.zeros_like(logits_sort, dtype=torch.bool)
 
-    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
-    top_k_mask = logits_sort < boundary
-    logits_sort.masked_fill_(top_k_mask, -float("inf"))
-    cutoff = top_k_mask.sum(dim=-1).min()
-    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
-    top_p_mask[:, -1] = True
-    strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device)
-    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
-    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+    cutoffs = top_k_mask.sum(dim=-1)
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=device).unsqueeze(1)
+    if p is not None:
+        global_cutoff = cutoffs.min()
+        active_part = logits_idx[:, global_cutoff:]
+        probs_sort = logits_sort[:, global_cutoff:].softmax(dim=-1)
+        cumprob = probs_sort.cumsum(dim=-1)
+        top_p_mask = (cumprob <= (1 - p.unsqueeze(1))) | (torch.arange(
+            probs_sort.size(1), device=device) == probs_sort.size(1) - 1)
+    else:
+        active_part = logits_idx
+        top_p_mask = torch.arange(vocab_size, device=device).expand(
+            batch_size, -1) >= cutoffs.unsqueeze(1)
 
+    valid_idx = (active_part + strides).masked_select(top_p_mask)
     logits_flatten = logits.flatten()
-    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
-    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
-    logits[valid_idx] = valid_logits
-    return logits.reshape(batch_size, vocab_size)
+    output = torch.full_like(logits_flatten, -float('inf'))
+    output[valid_idx] = logits_flatten[valid_idx]
+    return output.reshape(batch_size, vocab_size)
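
As a usage reference, a minimal sketch of calling the refactored helper with k, p, or both set to None; the batch shape and values are illustrative assumptions, and the import path follows the module referenced later in vllm_ascend/sample/sampler_v1.py:

import torch

from vllm_ascend.sample.ops.ascend_topk_topp_sampler import apply_top_k_top_p_npu

# Illustrative batch of 2 sequences over a 6-token vocabulary.
logits = torch.randn(2, 6)
k = torch.tensor([3, 2])      # per-request top-k
p = torch.tensor([0.9, 0.8])  # per-request top-p

both = apply_top_k_top_p_npu(logits, k, p)        # top-k and top-p
k_only = apply_top_k_top_p_npu(logits, k, None)   # top_p is None
p_only = apply_top_k_top_p_npu(logits, None, p)   # top_k is None
untouched = apply_top_k_top_p_npu(logits, None, None)  # logits returned unchanged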

vllm_ascend/sample/ops/penalties.py

Lines changed: 10 additions & 9 deletions
@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-
-from vllm.v1.sample.ops.penalties import _convert_to_tensors
 from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask
+from vllm.v1.sample.ops.penalties import _convert_to_tensors
 
 
 def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
@@ -31,23 +30,25 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
 
-
     repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
         1, vocab_size)
-
+
     # Avoid IndexPut operations in original apply_penalties function which are extremely time-consuming on NPU.
     sequence_mask = prompt_mask | output_mask
-    logits = torch.where(sequence_mask & torch.lt(logits, 0), logits * repetition_penalties,
-                         logits).to(logits.dtype)
-    logits = torch.where(sequence_mask & torch.ge(logits, 0), logits / repetition_penalties,
-                         logits).to(logits.dtype)
+    logits = torch.where(sequence_mask & torch.lt(logits, 0),
+                         logits * repetition_penalties,
+                         logits).to(logits.dtype)
+    logits = torch.where(sequence_mask & torch.ge(logits, 0),
+                         logits / repetition_penalties,
+                         logits).to(logits.dtype)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
     logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
 
+
 def apply_all_penalties(
     logits: torch.Tensor,
     prompt_token_ids: torch.Tensor,
@@ -64,4 +65,4 @@ def apply_all_penalties(
         logits.device)
     return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                            presence_penalties, frequency_penalties,
-                           repetition_penalties)
+                           repetition_penalties)
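
For context, a self-contained sketch of the scatter-free repetition-penalty pattern kept above (negative logits multiplied, non-negative logits divided); the tensors below are illustrative stand-ins, not values taken from vLLM:

import torch

# Illustrative: 2 sequences over a 4-token vocabulary.
logits = torch.tensor([[1.0, -2.0, 0.5, 3.0],
                       [0.2, 0.4, -1.0, 2.0]])
repetition_penalties = torch.tensor([1.2, 1.5]).unsqueeze(1).repeat(1, 4)
# True where a token appeared in the prompt or in the output so far.
sequence_mask = torch.tensor([[True, True, False, False],
                              [False, True, True, False]])

# Apply the penalty with torch.where instead of scatter/index_put.
logits = torch.where(sequence_mask & torch.lt(logits, 0),
                     logits * repetition_penalties, logits)
logits = torch.where(sequence_mask & torch.ge(logits, 0),
                     logits / repetition_penalties, logits)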

vllm_ascend/sample/sampler.py

Lines changed: 12 additions & 18 deletions
@@ -3,15 +3,13 @@
 from typing import Optional
 
 import torch
-from vllm.model_executor.layers.sampler import (Sampler,
-                                                SamplerOutput,
-                                                _apply_min_tokens_penalty,
-                                                _apply_min_p,
-                                                _sample,
-                                                SampleResultArgsType,
-                                                get_logprobs,
-                                                _build_sampler_output)
+from vllm.model_executor.layers.sampler import (Sampler, SampleResultArgsType,
+                                                SamplerOutput, _apply_min_p,
+                                                _apply_min_tokens_penalty,
+                                                _build_sampler_output, _sample,
+                                                get_logprobs)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+
 from vllm_ascend.sample.ops.penalties import apply_penalties
 
 
@@ -61,7 +59,7 @@ def forward(
 
         if do_top_p_top_k:
             logits = _apply_top_k_top_p_npu(logits, sampling_tensors.top_ps,
-                                             sampling_tensors.top_ks)
+                                            sampling_tensors.top_ks)
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -83,21 +81,15 @@ def forward(
         )
 
         if self.include_gpu_probs_tensor:
-            # Since we will defer sampler result Pythonization,
-            # preserve GPU-side tensors in support of later
-            # deferred pythonization of logprobs
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
         else:
-            # Since Pythonization has already happened, don't preserve
-            # GPU-side tensors.
            on_device_tensors = None
 
         # Get the logprobs query results.
         prompt_logprobs = None
         sample_logprobs = None
         if not sampling_metadata.skip_sampler_cpu_output:
-            # Pythonize logprobs now (GPU -> CPU); do not defer.
             assert not isinstance(maybe_deferred_sample_results,
                                   SampleResultArgsType)
             prompt_logprobs, sample_logprobs = get_logprobs(
@@ -121,10 +113,9 @@ def _apply_top_k_top_p_npu(
 
     This algorithm avoids using torch.scatter which is time-consuming on NPU.
     """
-    # TODO(linfeng): consider the case taht either p or k is applied
     batch_size, vocab_size = logits.shape
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
+
     boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
     top_k_mask = logits_sort < boundary
     logits_sort.masked_fill_(top_k_mask, -float("inf"))
@@ -133,7 +124,10 @@ def _apply_top_k_top_p_npu(
     probs_sum = probs_sort.cumsum(dim=-1)
     top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
     top_p_mask[:, -1] = True
-    strides = torch.arange(0, batch_size*vocab_size, vocab_size, device=logits.device)
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=logits.device)
     flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
     valid_idx = torch.masked_select(flatten_idx, top_p_mask)
     logits_flatten = logits.flatten()

vllm_ascend/sample/sampler_v1.py

Lines changed: 5 additions & 7 deletions
@@ -1,13 +1,11 @@
 import torch
-from vllm.v1.sample.sampler import Sampler
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.penalties import apply_min_token_penalties
-from vllm.logger import init_logger
-from vllm_ascend.sample.ops.ascend_topk_topp_sampler import AscendTopKTopPSampler
-from vllm_ascend.sample.ops.penalties import apply_all_penalties
-
+from vllm.v1.sample.sampler import Sampler
 
-logger = init_logger(__name__)
+from vllm_ascend.sample.ops.ascend_topk_topp_sampler import \
+    AscendTopKTopPSampler
+from vllm_ascend.sample.ops.penalties import apply_all_penalties
 
 
 class AscendSampler(Sampler):
@@ -35,4 +33,4 @@ def apply_penalties(
             sampling_metadata.repetition_penalties,
             sampling_metadata.output_token_ids,
         )
-        return logits
+        return logits

vllm_ascend/worker/model_runner.py

Lines changed: 2 additions & 6 deletions
@@ -60,6 +60,7 @@
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
+
 from vllm_ascend.sample.sampler import AscendSampler
 
 if TYPE_CHECKING:
@@ -821,12 +822,7 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
-            # Same options with those in model_runner_v1.py
-            # option 1
-            if hasattr(self.model, "sampler"):
-                self.model.sampler = AscendSampler()
-            # option 2
-            # self.model = NPUModelWrapperV1(model)
+            self.model.sampler = AscendSampler()
         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))
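
The sampler wiring reduces to a plain attribute assignment after the model is loaded; a minimal sketch of the pattern, where the wrapper function below is a simplified stand-in rather than the runner's actual method:

from vllm.model_executor.model_loader import get_model

from vllm_ascend.sample.sampler import AscendSampler


def load_model_with_ascend_sampler(vllm_config):
    # Load the model as usual, then unconditionally replace its sampler
    # with the NPU-friendly implementation.
    model = get_model(vllm_config=vllm_config)
    model.sampler = AscendSampler()
    return model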

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 29 deletions
@@ -33,9 +33,7 @@
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.sampler import sampler_output
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingType
@@ -813,11 +811,7 @@ def load_model(self) -> None:
 
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
-            # option 1
-            if hasattr(self.model, "sampler"):
-                self.model.sampler = AscendSampler()
-            # option 2
-            # self.model = NPUModelWrapperV1(model)
+            self.model.sampler = AscendSampler()
 
         if self.lora_config:
             raise ValueError("LoRA model is not supported on NPU now.")
@@ -898,25 +892,3 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
                     f"Unknown attention type: {attn_module.attn_type}")
 
         return kv_cache_spec
-
-# class NPUModelWrapperV1(nn.Module):
-
-#     def __init__(self, model: nn.Module):
-#         super().__init__()
-#         self._model = model
-#         self.sampler = AscendSampler()
-
-#     def __getattr__(self, name):
-#         return getattr(self._model, name)
-
-#     def sample(
-#         self,
-#         logits: Optional[torch.Tensor],
-#         sampling_metadata: SamplingMetadata,
-#     ) -> Optional[SamplerOutput]:
-#         next_tokens = self.sampler(logits, sampling_metadata)
-#         return next_tokens
-
-#     def forward():
-#         # necessary if using wrapper class
-#         pass
