Skip to content

Commit 1346edf

Browse files
committed
perf: speed up topk_topp_sampler
1 parent 9f5ab59 commit 1346edf

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

vllm_ascend/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import vllm_ascend.ops.layernorm # noqa
2424
import vllm_ascend.ops.rotary_embedding # noqa
2525
import vllm_ascend.ops.vocab_parallel_embedding # noqa
26+
import vllm_ascend.ops.topk_topp_sampler  # noqa
2627

2728

2829
class dummyFusionOp:
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from typing import Optional
3+
import torch
4+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, apply_top_k_top_p_tpu, random_sample
5+
6+
7+
def forward_npu(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Filter ``logits`` with top-k / top-p and draw a random sample.

    The logits are passed through ``apply_top_k_top_p_tpu``, normalised with
    a float32 softmax over the last dimension, and handed to
    ``random_sample`` together with ``generators``.

    Args:
        self: The sampler instance this function is bound to; not read here.
        logits: Raw scores to sample from.
        generators: Mapping from an integer key to a ``torch.Generator``,
            forwarded unchanged to ``random_sample`` (presumably keyed by
            request index in the batch -- confirm against ``random_sample``).
        k: Optional top-k values, forwarded to ``apply_top_k_top_p_tpu``.
        p: Optional top-p values, forwarded to ``apply_top_k_top_p_tpu``.

    Returns:
        Whatever ``random_sample`` returns for the computed probabilities
        (presumably the sampled token ids -- confirm against its definition).
    """
    filtered_logits = apply_top_k_top_p_tpu(logits, k, p)
    # The softmax is always evaluated in float32, independent of the dtype
    # of the incoming logits.
    probabilities = torch.softmax(filtered_logits, dim=-1, dtype=torch.float32)
    return random_sample(probabilities, generators)
17+
18+
19+
# Monkey-patch: rebind TopKTopPSampler.forward_native to forward_npu. This is
# a module-level statement, so the override takes effect as soon as this
# module is imported.
TopKTopPSampler.forward_native = forward_npu

0 commit comments

Comments
 (0)