optimize the function of computing top-k and top-p in sampler. #970
New file (147 lines added, `@@ -0,0 +1,147 @@`):

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional

import torch
from vllm.v1.sample.sampler import Sampler  # noqa: F401

# Tolerances for comparing the optimized ops against the reference ones.
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def apply_min_p_new(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    if min_p == 0:
        return logits
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Identify valid tokens using threshold comparison
    # Apply mask using boolean indexing
    logits = logits.masked_fill(probability_values < adjusted_min_p,
                                -float('inf'))
    return logits


def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def apply_top_k_top_p_new(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k: mask everything below the k-th largest logit per row.
    # (gather requires int64 indices.)
    boundary = logits_sort.gather(
        1, (vocab_size - k.to(torch.long)).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p. Every row masks at least `cutoff` entries under
        # top-k, so those leading columns can be skipped entirely.
        cutoff = top_k_mask.sum(dim=-1).min()
        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
        # Keep at least one token.
        top_p_mask[:, -1] = True
        strides = torch.arange(0,
                               batch_size * vocab_size,
                               vocab_size,
                               device=logits.device)
        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
        logits_flatten = logits.flatten()
        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
        logits[valid_idx] = valid_logits
        return logits.reshape(batch_size, vocab_size)

    # Top-p disabled: scatter the top-k-masked logits back to original order.
    return logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)


# The leading dimension merges batch_size and seqlen into num_tokens.
@torch.inference_mode()
def test_apply_min_p() -> None:
    logits = torch.randn((128, 7168)).npu()
    min_p = torch.Tensor([0.01]).npu()
    logits_new = apply_min_p_new(logits, min_p)
    sampler = Sampler()
    logits_old = sampler.apply_min_p(logits, min_p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)


# The leading dimension merges batch_size and seqlen into num_tokens.
@torch.inference_mode()
def test_apply_top_k_top_p() -> None:
    logits = torch.randn((128, 7168)).npu()
    # k must lie in [1, vocab_size] and p in (0, 1] for the masks to be valid.
    k = torch.Tensor([50]).int().npu()
    p = torch.Tensor([0.9]).npu()
    logits_new = apply_top_k_top_p_new(logits, k, p)
    logits_old = apply_top_k_top_p(logits, k, p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
```
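
The top-p branch of `apply_top_k_top_p_new` restores surviving logits through flattened indexing instead of a per-row scatter: row offsets (`strides`) turn the per-row sorted indices into global indices into `logits.flatten()`, so one index assignment writes every kept entry. A toy CPU sketch of that trick, with illustrative values that are not part of the PR:

```python
import torch

# Toy illustration of the flattened-index restore used in
# apply_top_k_top_p_new: row offsets turn per-row sorted indices into
# global indices into logits.flatten(), so kept logits are written back
# with a single index assignment instead of a per-row scatter.
batch_size, vocab_size = 2, 4
logits = torch.tensor([[0.1, 2.0, 0.3, 1.0],
                       [1.5, 0.2, 0.4, 3.0]])
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
strides = torch.arange(0, batch_size * vocab_size, vocab_size)
flatten_idx = logits_idx + strides.unsqueeze(dim=1)

# Pretend top-p kept the two largest entries of each row.
keep = torch.zeros_like(logits, dtype=torch.bool)
keep[:, -2:] = True

valid_idx = torch.masked_select(flatten_idx, keep)
out = torch.full((batch_size * vocab_size,), -float("inf"))
out[valid_idx] = logits.flatten()[valid_idx]
print(out.reshape(batch_size, vocab_size))
# row 0 keeps 2.0 and 1.0; row 1 keeps 1.5 and 3.0
```

Slicing off the first `cutoff` columns also shrinks the softmax and cumsum to the tail of the vocabulary that can actually survive top-k, which is where the saving over the reference `apply_top_k_top_p` comes from.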
```diff
@@ -362,7 +362,7 @@ def fused_experts(
             num_experts)).to(topk_ids.dtype)

     # Sort by local expert IDs
-    sort_indices = torch.argsort(filtered_experts)
+    sort_indices = torch.argsort(filtered_experts.view(torch.float32))
     sorted_token_indices = token_indices[sort_indices]
     sorted_weights = filtered_weights[sort_indices]
```
Collaborator: Why do we change the dtype of `filtered_experts` before the sort here?

Author: sort can use aicore under float32, with better performance.
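
The reinterpret-cast works because, for small non-negative int32 values (well below the bit patterns of NaN and inf), the IEEE-754 float32 interpretation of the bits increases monotonically with the integer value, so sorting the reinterpreted tensor yields the same order while letting the sort run on AI Core. A minimal CPU sanity check of that property, separate from the PR:

```python
import torch

# Expert IDs are small non-negative ints, so their bit patterns,
# reinterpreted as float32 (here: denormals), preserve the ordering.
expert_ids = torch.randint(0, 64, (4096,), dtype=torch.int32)
order_int = torch.argsort(expert_ids)
order_f32 = torch.argsort(expert_ids.view(torch.float32))
# Ties may land in a different order, so compare the sorted values,
# not the index permutations themselves.
assert torch.equal(expert_ids[order_int], expert_ids[order_f32])
```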
New file (101 lines added, `@@ -0,0 +1,101 @@`):

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

import torch
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler

from vllm_ascend import envs


def apply_min_p(
    self,
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Identify valid tokens using threshold comparison
    # Apply mask using boolean indexing
    logits = logits.masked_fill(probability_values < adjusted_min_p,
                                -float('inf'))
    return logits


def _apply_top_k_top_p(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
        top_k_count = top_k_count.unsqueeze(dim=1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)

        # Rows with k == vocab_size keep every token (no top-k filtering).
        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

        elements_to_discard = probs < top_k_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    if p is not None:
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # at least one

        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits


def topk_topp_forward_native(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
    logits = _apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)


Sampler.apply_min_p = apply_min_p
if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMZE:
    TopKTopPSampler.forward_native = topk_topp_forward_native
```
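
The final three lines apply the patch at import time: `Sampler.apply_min_p` is rebound unconditionally, while the native top-k/top-p path is swapped only when `VLLM_ASCEND_ENABLE_TOPK_OPTIMZE` is set. A sketch of how it would be activated; the patch module's import path is not shown in this diff, so the name below is purely illustrative:

```python
import os

# Must be set before vllm_ascend.envs is first imported, since env flags
# are typically read once at import time.
os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"

# Importing the patch module performs the rebinding shown above.
import vllm_ascend.patch.patch_sampler  # noqa: F401  (illustrative path)
```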