
Commit 07c6282

Author: wangxiaoxin (A)

add optimization of dsv3.

Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>

1 parent 5a1689f commit 07c6282

8 files changed: +257 -3 lines changed

tests/e2e/doctests/001-quickstart-test.sh

Lines changed: 8 additions & 0 deletions
@@ -47,9 +47,17 @@ function quickstart_online_test() {
     wait_for_exit "$VLLM_PID"
 }
 
+function quickstart_offline_test_topk() {
+    export VLLM_ENABLE_TOPK_OPTIMZE=1
+    # Do real script test
+    python3 "${SCRIPT_DIR}/../../examples/offline_inference_npu.py"
+}
+
 _info "====> Start simple_test"
 simple_test
 _info "====> Start quickstart_offline_test"
 quickstart_offline_test
 _info "====> Start quickstart_online_test"
 quickstart_online_test
+_info "====> Start quickstart_offline_test_topk"
+quickstart_offline_test_topk

tests/singlecard/test_sampler.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional

import torch

from vllm.v1.sample.ops.topk_topp_sampler import \
    apply_top_k_top_p  # noqa: F401
from vllm.v1.sample.sampler import Sampler  # noqa: F401

# Set tolerance to 1 for quant ops
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def apply_min_p_new(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    if min_p == 0:
        return logits
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Identify valid tokens using threshold comparison
    # Apply mask using boolean indexing
    logits = logits.masked_fill(probability_values < adjusted_min_p,
                                -float('inf'))
    return logits


def apply_top_k_top_p_new(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p.
        cutoff = top_k_mask.sum(dim=-1).min()
        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = True
        strides = torch.arange(0,
                               batch_size * vocab_size,
                               vocab_size,
                               device=logits.device)
        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
        logits_flatten = logits.flatten()
        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
        logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)


# test with leading dimension and merge seqlen and batch_size as num_tokens
@torch.inference_mode()
def test_apply_min_p() -> None:
    logits = torch.randn((128, 7168)).npu()
    min_p = torch.Tensor([0.01]).npu()
    logits_new = apply_min_p_new(logits, min_p)
    sampler = Sampler()
    logits_old = sampler.apply_min_p(logits, min_p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)


# test with leading dimension and merge seqlen and batch_size as num_tokens
@torch.inference_mode()
def test_apply_top_k_top_p() -> None:
    logits = torch.randn((128, 7168)).npu()
    k = torch.Tensor([-1]).int().npu()
    p = torch.Tensor([1]).int().npu()
    logits_new = apply_top_k_top_p_new(logits, k, p)
    logits_old = apply_top_k_top_p(logits, k, p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
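For reference, a tiny CPU-only illustration of the min-p rule these tests check (the numbers are made up; a token survives only if its probability reaches min_p times the largest probability in its row):

import torch

logits = torch.tensor([[2.0, 1.0, -3.0]])
probs = logits.softmax(dim=-1)        # roughly [0.73, 0.27, 0.005]
min_p = torch.tensor([0.1])
keep = probs >= min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
print(keep)                           # tensor([[ True,  True, False]])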

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -225,8 +225,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
 
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.detach()
 
         if self.tp_size > 1:
             # pass
@@ -265,6 +264,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             final_hidden_states = router_hidden_states
 
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
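A minimal sketch of the reordering this hunk performs (the helper name routed_experts_path is hypothetical, and the stated motivation is an assumption: the detached copy preserves the pre-communication activations so the shared experts can run after the routed-expert path instead of up front):

def moe_forward_sketch(self, hidden_states):
    # Keep a copy of the original activations before any TP reshuffling.
    old_hidden_states = hidden_states.detach()
    # Routed experts run first (dispatch, expert GEMMs, combine, ...).
    final_hidden_states = self.routed_experts_path(hidden_states)  # hypothetical helper
    # Shared experts now consume the saved copy.
    if self.n_shared_experts is not None:
        final_hidden_states = final_hidden_states + self.shared_experts(old_hidden_states)
    return final_hidden_states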

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def fused_experts(
                                  num_experts)).to(topk_ids.dtype)
 
     # Sort by local expert IDs
-    sort_indices = torch.argsort(filtered_experts)
+    sort_indices = torch.argsort(filtered_experts.view(torch.float32))
     sorted_token_indices = token_indices[sort_indices]
     sorted_weights = filtered_weights[sort_indices]
 
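The changed line reinterprets the integer expert IDs as float32 bits before sorting. A small CPU-only check of the assumption behind it (that for small non-negative int32 IDs the bit pattern is monotonic under float comparison, so the sorted order is unchanged while the float sort path may be faster on the NPU; the performance claim itself is not verified here):

import torch

ids = torch.randint(0, 64, (1024,), dtype=torch.int32)  # stand-in for filtered_experts
order_int = torch.argsort(ids)
order_float = torch.argsort(ids.view(torch.float32))    # reinterpret bits, same shape
# Element order can differ among ties, but the sorted ID sequences agree.
assert torch.equal(ids[order_int], ids[order_float])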

vllm_ascend/patch/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -166,3 +166,30 @@
 #   Future Plan:
 #       Revert it when the ascend support triton kernel.
 #
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
+#    Why:
+#       We need to use the patched `apply_top_k_top_p` in `sample`.
+#       The main reason to overwrite `apply_top_k_top_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_top_k_top_p` function in PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the ascend scatter performance improves.
+#
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
+#    Why:
+#       We need to use the patched `apply_min_p` in `sample`.
+#       The main reason to overwrite `apply_min_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_min_p` function in PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the ascend indexput performance improves.

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,4 +23,5 @@
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler

from vllm_ascend import envs


def apply_min_p(
    self,
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Identify valid tokens using threshold comparison
    # Apply mask using boolean indexing
    logits = logits.masked_fill(probability_values < adjusted_min_p,
                                -float('inf'))
    return logits


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)

    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)


def topk_topp_forward_native(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
    logits = _apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)


Sampler.apply_min_p = apply_min_p
if envs.VLLM_ENABLE_TOPK_OPTIMZE:
    TopKTopPSampler.forward_native = topk_topp_forward_native
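An end-to-end usage sketch under stated assumptions: the flag is read when this patch module is imported, so it has to be exported before vLLM starts; the model name and prompt are illustrative only.

import os

os.environ["VLLM_ENABLE_TOPK_OPTIMZE"] = "1"  # set before vllm/vllm_ascend are imported

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_k=50, top_p=0.9, min_p=0.01)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)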
