Skip to content

Commit 7940e8e

Browse files
wangxiaoxin (A)wangxiaoxin-sherie
authored andcommitted
add optimze of dsv3.
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com> Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
1 parent 5a1689f commit 7940e8e

File tree

8 files changed

+315
-3
lines changed

8 files changed

+315
-3
lines changed

tests/singlecard/test_offline_inference.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import pytest
2626
import vllm # noqa: F401
27+
from vllm import SamplingParams
2728
from vllm.assets.image import ImageAsset
2829

2930
import vllm_ascend # noqa: F401
@@ -81,3 +82,27 @@ def test_multimodal(model, prompt_template, vllm_runner):
8182
vllm_model.generate_greedy(prompts=prompts,
8283
images=images,
8384
max_tokens=64)
85+
86+
87+
@pytest.mark.parametrize("model", MODELS)
88+
@pytest.mark.parametrize("dtype", ["half", "float16"])
89+
@pytest.mark.parametrize("max_tokens", [5])
90+
def test_models_topk(model: str, dtype: str, max_tokens: int) -> None:
91+
os.environ["VLLM_ENABLE_TOPK_OPTIMZE"] = "1"
92+
example_prompts = [
93+
"Hello, my name is",
94+
"The president of the United States is",
95+
"The capital of France is",
96+
"The future of AI is",
97+
]
98+
sampling_params = SamplingParams(max_tokens=max_tokens,
99+
temperature=0.0,
100+
top_k=50,
101+
top_p=0.9)
102+
103+
with VllmRunner(model,
104+
max_model_len=8192,
105+
dtype=dtype,
106+
enforce_eager=True,
107+
gpu_memory_utilization=0.7) as vllm_model:
108+
vllm_model.generate(example_prompts, sampling_params)

tests/singlecard/test_sampler.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
5+
# Copyright 2023 The vLLM team.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
from typing import Optional
20+
21+
import torch
22+
23+
from vllm.v1.sample.sampler import Sampler # noqa: F401
24+
25+
# Set tolerance to 1 for quant ops
26+
DEFAULT_ATOL = 1e-3
27+
DEFAULT_RTOL = 1e-3
28+
29+
30+
def apply_min_p_new(
31+
logits: torch.Tensor,
32+
min_p: torch.Tensor,
33+
) -> torch.Tensor:
34+
"""
35+
Filters logits using adaptive probability thresholding.
36+
"""
37+
if min_p == 0:
38+
return logits
39+
# Convert logits to probability distribution
40+
probability_values = torch.nn.functional.softmax(logits, dim=-1)
41+
# Calculate maximum probabilities per sequence
42+
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
43+
# Reshape min_p for broadcasting
44+
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
45+
# Identify valid tokens using threshold comparison
46+
# Apply mask using boolean indexing
47+
logits = logits.masked_fill(probability_values < adjusted_min_p,
48+
-float('inf'))
49+
return logits
50+
51+
52+
def apply_top_k_top_p(
53+
logits: torch.Tensor,
54+
k: Optional[torch.Tensor],
55+
p: Optional[torch.Tensor],
56+
) -> torch.Tensor:
57+
"""Apply top-k and top-p masks to the logits.
58+
59+
If a top-p is used, this function will sort the logits tensor,
60+
which can be slow for large batches.
61+
62+
The logits tensor may be updated in-place.
63+
"""
64+
if p is None:
65+
if k is None:
66+
return logits
67+
68+
# Avoid sorting vocab for top-k only case.
69+
return apply_top_k_only(logits, k)
70+
71+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
72+
73+
if k is not None:
74+
# Apply top-k.
75+
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
76+
# Get all the top_k values.
77+
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
78+
top_k_mask = logits_sort < top_k_mask
79+
logits_sort.masked_fill_(top_k_mask, -float("inf"))
80+
81+
if p is not None:
82+
# Apply top-p.
83+
probs_sort = logits_sort.softmax(dim=-1)
84+
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
85+
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
86+
# at least one
87+
top_p_mask[:, -1] = False
88+
logits_sort.masked_fill_(top_p_mask, -float("inf"))
89+
90+
# Re-sort the probabilities.
91+
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
92+
return logits
93+
94+
95+
def apply_top_k_top_p_new(
96+
logits: torch.Tensor,
97+
k: Optional[torch.Tensor],
98+
p: Optional[torch.Tensor],
99+
) -> torch.Tensor:
100+
batch_size, vocab_size = logits.shape
101+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
102+
103+
# Apply top-k.
104+
boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
105+
top_k_mask = logits_sort < boundary
106+
logits_sort.masked_fill_(top_k_mask, -float("inf"))
107+
108+
if p is not None:
109+
# Apply top-p.
110+
cutoff = top_k_mask.sum(dim=-1).min()
111+
probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
112+
probs_sum = probs_sort.cumsum(dim=-1)
113+
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
114+
top_p_mask[:, -1] = True
115+
strides = torch.arange(0,
116+
batch_size * vocab_size,
117+
vocab_size,
118+
device=logits.device)
119+
flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
120+
valid_idx = torch.masked_select(flatten_idx, top_p_mask)
121+
logits_flatten = logits.flatten()
122+
valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
123+
logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
124+
logits[valid_idx] = valid_logits
125+
return logits.reshape(batch_size, vocab_size)
126+
127+
128+
# test with leading dimension and merge seqlen and batch_size as num_tokens
129+
@torch.inference_mode()
130+
def test_apply_min_p() -> None:
131+
logits = torch.randn((128, 7168)).npu()
132+
min_p = torch.Tensor([0.01]).npu()
133+
logits_new = apply_min_p_new(logits, min_p)
134+
sampler = Sampler()
135+
logits_old = sampler.apply_min_p(logits, min_p)
136+
# Compare the results.
137+
torch.testing.assert_close(logits_new,
138+
logits_old,
139+
atol=DEFAULT_ATOL,
140+
rtol=DEFAULT_RTOL)
141+
142+
143+
# test with leading dimension and merge seqlen and batch_size as num_tokens
144+
@torch.inference_mode()
145+
def test_apply_top_k_top_p() -> None:
146+
logits = torch.randn((128, 7168)).npu()
147+
k = torch.Tensor([-1]).int().npu()
148+
p = torch.Tensor([1]).int().npu()
149+
logits_new = apply_top_k_top_p_new(logits, k, p)
150+
logits_old = apply_top_k_top_p(logits, k, p)
151+
# Compare the results.
152+
torch.testing.assert_close(logits_new,
153+
logits_old,
154+
atol=DEFAULT_ATOL,
155+
rtol=DEFAULT_RTOL)

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
3737
"VLLM_ENABLE_MC2":
3838
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
39+
"VLLM_ENABLE_TOPK_OPTIMZE":
40+
lambda: bool(int(os.getenv("VLLM_ENABLE_TOPK_OPTIMZE", '0'))),
3941
"USING_LCCL_COM":
4042
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
4143
"SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
225225
enable_force_load_balance = False
226226
num_tokens, hidden_dim = hidden_states.shape
227227

228-
if self.n_shared_experts is not None:
229-
shared_output = self.shared_experts(hidden_states)
228+
old_hidden_states = hidden_states.detach()
230229

231230
if self.tp_size > 1:
232231
# pass
@@ -265,6 +264,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
265264
else:
266265
final_hidden_states = router_hidden_states
267266

267+
if self.n_shared_experts is not None:
268+
shared_output = self.shared_experts(old_hidden_states)
269+
268270
if shared_output is not None:
269271
final_hidden_states = final_hidden_states + shared_output
270272

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def fused_experts(
363363
num_experts)).to(topk_ids.dtype)
364364

365365
# Sort by local expert IDs
366-
sort_indices = torch.argsort(filtered_experts)
366+
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
367367
sorted_token_indices = token_indices[sort_indices]
368368
sorted_weights = filtered_weights[sort_indices]
369369

vllm_ascend/patch/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,30 @@
166166
# Future Plan:
167167
# Revert it when the ascend support triton kernel.
168168
#
169+
# ** File: v1/sample/sampler.py **
170+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
171+
# 1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
172+
# Why:
173+
# We need to use the patched `apply_top_k_top_p` in `sample`.
174+
# The mainly reason to overwrite `apply_top_k_top_p` is
175+
# to improve performance.
176+
# How:
177+
# Re-implementation the `apply_top_k_top_p` function by pytorch
178+
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
179+
# - https://github.com/vllm-project/vllm-ascend/pull/970
180+
# Future Plan:
181+
# Revert it when the ascend scatter performance improves.
182+
#
183+
# ** File: v1/sample/sampler.py **
184+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~s
185+
# 1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
186+
# Why:
187+
# We need to use the patched `apply_min_p` in `sample`.
188+
# The mainly reason to overwrite `apply_min_p` is
189+
# to improve performance.
190+
# How:
191+
# Re-implementation the `apply_min_p` function by pytorch
192+
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
193+
# - https://github.com/vllm-project/vllm-ascend/pull/970
194+
# Future Plan:
195+
# Revert it when the ascend indexput performance improves.

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@
2323
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
2424
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
2525
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
26+
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
2627
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Optional
19+
20+
import torch
21+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
22+
from vllm.v1.sample.sampler import Sampler
23+
24+
from vllm_ascend import envs
25+
26+
27+
def apply_min_p(
28+
self,
29+
logits: torch.Tensor,
30+
min_p: torch.Tensor,
31+
) -> torch.Tensor:
32+
"""
33+
Filters logits using adaptive probability thresholding.
34+
"""
35+
# Convert logits to probability distribution
36+
probability_values = torch.nn.functional.softmax(logits, dim=-1)
37+
# Calculate maximum probabilities per sequence
38+
max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
39+
# Reshape min_p for broadcasting
40+
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
41+
# Identify valid tokens using threshold comparison
42+
# Apply mask using boolean indexing
43+
logits = logits.masked_fill(probability_values < adjusted_min_p,
44+
-float('inf'))
45+
return logits
46+
47+
48+
def _apply_top_k_top_p(
49+
logits: torch.Tensor,
50+
p: torch.Tensor,
51+
k: torch.Tensor,
52+
) -> torch.Tensor:
53+
batch_size, vocab_size = logits.shape
54+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
55+
56+
# Apply top-k.
57+
boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
58+
top_k_mask = logits_sort < boundary
59+
logits_sort.masked_fill_(top_k_mask, -float("inf"))
60+
61+
# Apply top-p.
62+
cutoff = top_k_mask.sum(dim=-1).min()
63+
probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
64+
probs_sum = probs_sort.cumsum(dim=-1)
65+
top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
66+
67+
top_p_mask[:, -1] = True
68+
strides = torch.arange(0,
69+
batch_size * vocab_size,
70+
vocab_size,
71+
device=logits.device)
72+
flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
73+
valid_idx = torch.masked_select(flatten_idx, top_p_mask)
74+
logits_flatten = logits.flatten()
75+
valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
76+
logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
77+
logits[valid_idx] = valid_logits
78+
return logits.reshape(batch_size, vocab_size)
79+
80+
81+
def topk_topp_forward_native(
82+
self,
83+
logits: torch.Tensor,
84+
generators: dict[int, torch.Generator],
85+
k: Optional[torch.Tensor],
86+
p: Optional[torch.Tensor],
87+
) -> torch.Tensor:
88+
"""
89+
PyTorch-native implementation of top-k and top-p sampling.
90+
91+
The logits tensor may be updated in-place.
92+
"""
93+
logits = _apply_top_k_top_p(logits, k, p)
94+
probs = logits.softmax(dim=-1, dtype=torch.float32)
95+
return random_sample(probs, generators)
96+
97+
98+
Sampler.apply_min_p = apply_min_p
99+
if envs.VLLM_ENABLE_TOPK_OPTIMZE:
100+
TopKTopPSampler.forward_native = topk_topp_forward_native

0 commit comments

Comments
 (0)