From 4be61f74f1279362e7e73f4cbeaedcd574e4def1 Mon Sep 17 00:00:00 2001 From: "wangxiaoxin (A)" Date: Thu, 5 Jun 2025 10:29:30 +0800 Subject: [PATCH] add optimze of dsv3. After testing, the tpu_apply_top_k_top_p function achieves optimal performance. Signed-off-by: wangxiaoxin (A) Co-authored-by: ZhengWG --- .../test_offline_inference_distributed.py | 24 +++ tests/singlecard/test_offline_inference.py | 23 +++ tests/singlecard/test_sampler.py | 147 ++++++++++++++++++ vllm_ascend/envs.py | 2 + vllm_ascend/models/deepseek_v2.py | 6 +- vllm_ascend/ops/fused_moe.py | 2 +- vllm_ascend/patch/__init__.py | 27 ++++ .../patch/worker/patch_common/__init__.py | 1 + .../worker/patch_common/patch_sampler.py | 101 ++++++++++++ 9 files changed, 330 insertions(+), 3 deletions(-) create mode 100644 tests/singlecard/test_sampler.py create mode 100644 vllm_ascend/patch/worker/patch_common/patch_sampler.py diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 941055cf72..84a393064d 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -21,8 +21,10 @@ Run `pytest tests/test_offline_inference.py`. """ import os +from unittest.mock import patch import vllm # noqa: F401 +from vllm import SamplingParams from tests.conftest import VllmRunner @@ -61,3 +63,25 @@ def test_models_distributed_DeepSeek(): distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"}) +def test_models_distributed_topk() -> None: + example_prompts = [ + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + ] + dtype = "half" + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner( + "deepseek-ai/DeepSeek-V2-Lite", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index 5d0e16e5fd..d6c36c7f46 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -21,9 +21,11 @@ Run `pytest tests/test_offline_inference.py`. 
""" import os +from unittest.mock import patch import pytest import vllm # noqa: F401 +from vllm import SamplingParams from vllm.assets.image import ImageAsset import vllm_ascend # noqa: F401 @@ -81,3 +83,24 @@ def test_multimodal(model, prompt_template, vllm_runner): vllm_model.generate_greedy(prompts=prompts, images=images, max_tokens=64) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"}) +def test_models_topk() -> None: + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct", + max_model_len=8192, + dtype="float16", + enforce_eager=True, + gpu_memory_utilization=0.7) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/singlecard/test_sampler.py b/tests/singlecard/test_sampler.py new file mode 100644 index 0000000000..b21142018e --- /dev/null +++ b/tests/singlecard/test_sampler.py @@ -0,0 +1,147 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import torch +from vllm.v1.sample.sampler import Sampler # noqa: F401 + +# Set tolerance to 1 for quant ops +DEFAULT_ATOL = 1e-3 +DEFAULT_RTOL = 1e-3 + + +def apply_min_p_new( + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Filters logits using adaptive probability thresholding. + """ + if min_p == 0: + return logits + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + # Apply mask using boolean indexing + logits = logits.masked_fill(probability_values < adjusted_min_p, + -float('inf')) + return logits + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + If a top-p is used, this function will sort the logits tensor, + which can be slow for large batches. + + The logits tensor may be updated in-place. + """ + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if k is not None: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. 
+ probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def apply_top_k_top_p_new( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + batch_size, vocab_size = logits.shape + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) + top_k_mask = logits_sort < boundary + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if p is not None: + # Apply top-p. + cutoff = top_k_mask.sum(dim=-1).min() + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = True + strides = torch.arange(0, + batch_size * vocab_size, + vocab_size, + device=logits.device) + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) + valid_idx = torch.masked_select(flatten_idx, top_p_mask) + logits_flatten = logits.flatten() + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) + logits[valid_idx] = valid_logits + return logits.reshape(batch_size, vocab_size) + + +# test with leading dimension and merge seqlen and batch_size as num_tokens +@torch.inference_mode() +def test_apply_min_p() -> None: + logits = torch.randn((128, 7168)).npu() + min_p = torch.Tensor([0.01]).npu() + logits_new = apply_min_p_new(logits, min_p) + sampler = Sampler() + logits_old = sampler.apply_min_p(logits, min_p) + # Compare the results. + torch.testing.assert_close(logits_new, + logits_old, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + +# test with leading dimension and merge seqlen and batch_size as num_tokens +@torch.inference_mode() +def test_apply_top_k_top_p() -> None: + logits = torch.randn((128, 7168)).npu() + k = torch.Tensor([-1]).int().npu() + p = torch.Tensor([1]).int().npu() + logits_new = apply_top_k_top_p_new(logits, k, p) + logits_old = apply_top_k_top_p(logits, k, p) + # Compare the results. 
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index a4e9450b8b..9378e6f715 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ASCEND_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 515ebe1a8e..d92f9c5da9 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -241,8 +241,7 @@ def forward(
 
         num_tokens, hidden_size = hidden_states.shape
 
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
             if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
@@ -291,6 +290,9 @@ def forward(
             if num_padding_tokens > 0:
                 hidden_states = hidden_states[:-num_padding_tokens]
 
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             hidden_states = hidden_states + shared_output
 
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 4853b27282..43188295ec 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -362,7 +362,7 @@ def fused_experts(
                                         num_experts)).to(topk_ids.dtype)
 
     # Sort by local expert IDs
-    sort_indices = torch.argsort(filtered_experts)
+    sort_indices = torch.argsort(filtered_experts.view(torch.float32))
     sorted_token_indices = token_indices[sort_indices]
     sorted_weights = filtered_weights[sort_indices]
 
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 5660d62391..ccf9bd9e0c 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -166,3 +166,30 @@
 #    Future Plan:
 #       Revert it when the ascend support triton kernel.
 #
+#   ** File: v1/sample/ops/topk_topp_sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.ops.topk_topp_sampler.TopKTopPSampler.forward_native`
+#    Why:
+#       We need to use the patched `apply_top_k_top_p` during sampling.
+#       The main reason to overwrite `apply_top_k_top_p` is
+#       to improve performance; the patch is gated by `VLLM_ASCEND_ENABLE_TOPK_OPTIMZE`.
+#    How:
+#       Re-implement the `apply_top_k_top_p` function in pure PyTorch.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend scatter performance improves.
+#
+#   ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
+#    Why:
+#       We need to use the patched `apply_min_p` in `sample`.
+#       The main reason to overwrite `apply_min_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_min_p` function in pure PyTorch.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend index_put performance improves.
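
For reviewers, a minimal sketch of how to confirm the two monkey patches described above are active once this change is applied. It assumes a working vllm + vllm-ascend install; the environment variable must be set before the patch module is imported, since vllm_ascend.envs reads it lazily.

    import os

    os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"  # read when patch_sampler is imported

    from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
    from vllm.v1.sample.sampler import Sampler

    # Importing the module applies the module-level assignments in patch_sampler.py (added below).
    import vllm_ascend.patch.worker.patch_common.patch_sampler as patch_sampler

    assert Sampler.apply_min_p is patch_sampler.apply_min_p
    assert TopKTopPSampler.forward_native is patch_sampler.topk_topp_forward_native
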
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 5b55ac6b3e..7618823ba6 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -23,4 +23,5 @@ import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa +import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_sampler.py b/vllm_ascend/patch/worker/patch_common/patch_sampler.py new file mode 100644 index 0000000000..495404184e --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_sampler.py @@ -0,0 +1,101 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import torch +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample +from vllm.v1.sample.sampler import Sampler + +from vllm_ascend import envs + + +def apply_min_p( + self, + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Filters logits using adaptive probability thresholding. + """ + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) + # Reshape min_p for broadcasting + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities + # Identify valid tokens using threshold comparison + # Apply mask using boolean indexing + logits = logits.masked_fill(probability_values < adjusted_min_p, + -float('inf')) + return logits + + +def _apply_top_k_top_p( + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. 
+ no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + +def topk_topp_forward_native( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + PyTorch-native implementation of top-k and top-p sampling. + + The logits tensor may be updated in-place. + """ + logits = _apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + +Sampler.apply_min_p = apply_min_p +if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMZE: + TopKTopPSampler.forward_native = topk_topp_forward_native
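
As a usage reference, a minimal offline-inference sketch that exercises the optimized top-k/top-p path. The model name and sampling values mirror the new test_models_topk test and are illustrative only.

    import os

    os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"  # enable the patched top-k/top-p path

    from vllm import LLM, SamplingParams

    prompts = ["Hello, my name is", "The future of AI is"]
    sampling_params = SamplingParams(max_tokens=5,
                                     temperature=0.0,
                                     top_k=50,
                                     top_p=0.9)

    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
              max_model_len=8192,
              enforce_eager=True)
    for output in llm.generate(prompts, sampling_params):
        print(output.outputs[0].text)

Leaving VLLM_ASCEND_ENABLE_TOPK_OPTIMZE unset keeps vLLM's default forward_native, which makes it easy to compare outputs between the default and patched paths.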