
Conversation

njhill (Member) commented Mar 25, 2025

This adds a faster path for the case where there's top-k in the batch but no top-p.

Benchmark with 128k vocab, batch size 1024, 500 ops on an A100, where max top-k is 10:

Before: 11.571 sec
After: 2.136 sec
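
For context, a rough sketch of the kind of top-k-only fast path described above (a minimal illustration with hypothetical names, not the PR's exact code): instead of sorting the full vocab dimension, take topk bounded by the batch's maximum k, gather each row's own k-th largest logit as a cutoff, and mask everything below it.

import torch

def top_k_only_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # logits: [batch, vocab]; k: [batch] per-request top-k values.
    max_k = int(k.max())
    # Values of the max_k largest logits per row, sorted descending: [batch, max_k].
    topk_vals = logits.topk(max_k, dim=-1).values
    # Each row's own k-th largest value is the cutoff: [batch, 1].
    cutoff = topk_vals.gather(dim=-1, index=(k - 1).unsqueeze(-1).long())
    # Keep everything >= the cutoff; ties at the cutoff are also kept
    # (see the correctness discussion later in this thread).
    return logits.masked_fill(logits < cutoff, float("-inf"))

A caller would take this path only when no request in the batch uses top-p, falling back to the combined top-k/top-p implementation otherwise.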

github-actions (bot) commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

NickLucche (Collaborator) left a comment:

Tested on TPU: this won't work out of the box due to a broadcasting issue.

njhill (Member, Author) commented Mar 25, 2025

@NickLucche that's strange. Which op has that issue?

NickLucche (Collaborator) commented:

Not too surprising; torch_xla has more constraining rules on broadcasting.
This is the first error I encountered:

F0325 16:28:32.957930 1304047 debug_macros.h:21] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Input dimension should be either 1 or equal to the output dimension it is broadcasting into; the 0th operand dimension is 4, the 0th output dimension is 1.
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
	torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
	torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
	
	
	torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
	torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
	torch_xla::Gather::Gather(torch::lazy::Value const&, long, torch::lazy::Value const&)
	std::shared_ptr<torch::lazy::Node> torch_xla::MakeNode<torch_xla::Gather, torch::lazy::Value, long&, torch::lazy::Value>(torch::lazy::Value&&, long&, torch::lazy::Value&&)
	torch_xla::tensor_methods::gather(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, long, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&)
	torch_xla::XLANativeFunctions::gather(at::Tensor const&, long, at::Tensor const&, bool)
	
	
	at::_ops::gather::redispatch(c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&, bool)
	
	
	at::_ops::gather::call(at::Tensor const&, long, at::Tensor const&, bool)

This occurs on the .gather op. I expanded k but then ran into another issue.

"""
if k is None and p is None:
if p is None:
if k is None:
Collaborator commented:

Do we have a unit test checking the correctness of this?

Member Author (njhill) replied:

We should really have blanket coverage for this kind of thing, including different combinations of parameters (i.e. top-k with/without top-p etc.). I'm not sure whether we do though. I will check and add a unit test to compare the two impls.
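
A sketch of what such a comparison test might look like (hypothetical helper names; it reuses the top_k_only_sketch shown earlier and relies on random float logits, where exact ties are vanishingly unlikely):

import torch

def reference_top_k(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Slow but obviously-correct reference: keep only each row's top-k entries.
    out = torch.full_like(logits, float("-inf"))
    for i in range(logits.shape[0]):
        vals, idx = logits[i].topk(int(k[i]))
        out[i, idx] = vals
    return out

def test_top_k_only_matches_reference():
    torch.manual_seed(0)
    logits = torch.randn(16, 1000)
    k = torch.randint(1, 20, (16,))
    expected = reference_top_k(logits, k)
    actual = top_k_only_sketch(logits.clone(), k)
    torch.testing.assert_close(actual, expected)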

NickLucche (Collaborator) left a comment:

I tested this version again today and it's working on TPU too, nice one @njhill, thanks!
I was wondering, could we still factor out this top-k optimization into its own function so I can call it from the TPU side?
We agreed with @WoosukKwon to try to keep things separated; I'd like to keep forward_tpu around.

NickLucche (Collaborator) commented:

Something like a5bf849#diff-6047245d864bf5fd68b5b947b735beca94723bad40d20bfc0803d9b3eea5c1edR121-R136.
Wdyt? Of course I'd wait for this PR to land and then rebase, I've shamelessly just copy-pasted your code there.

njhill added 2 commits March 26, 2025 07:17
njhill (Member, Author) commented Mar 26, 2025

Thanks @NickLucche, I've split it into a separate function. And @WoosukKwon, I've added a correctness test.

njhill added the "ready" label (ONLY add when PR is ready to merge / full CI is needed) on Mar 26, 2025
WoosukKwon (Collaborator) left a comment:

LGTM! Thanks for addressing my comments.

njhill merged commit 35fad35 into vllm-project:main on Mar 26, 2025 (39 checks passed).
njhill deleted the torch-topk branch on March 26, 2025 at 17:56.
hyeygit (Contributor) commented Mar 30, 2025

@njhill really neat idea to threshold the logits! However, I think one corner case where this would break is if there are duplicate elements in the logits that equal the cutoff value (i.e. top_k_mask). For example, given an input of [1, 2, 2, 2, 3] and k=3, the current apply_top_k_only would return [-inf, 2, 2, 2, 3], while the correct result should be [-inf, -inf, 2, 2, 3].

In #15736 I used similar thresholding logic for top-p, but introduced a small random perturbation to break ties. Maybe the same idea can be used here for top-k as well.
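
To make the corner case and the proposed tie-break concrete, a minimal sketch (illustrative only; not the code from #15736):

import torch

logits = torch.tensor([[1.0, 2.0, 2.0, 2.0, 3.0]])
k = torch.tensor([3])

# Threshold-based masking keeps every element >= the k-th largest value,
# so all three 2.0s survive: [[-inf, 2, 2, 2, 3]].
cutoff = logits.topk(int(k.max()), dim=-1).values.gather(-1, (k - 1).unsqueeze(-1))
print(logits.masked_fill(logits < cutoff, float("-inf")))

# Hypothetical tie-break along the lines suggested above: add a tiny random
# perturbation so exact ties become unique, compute the cutoff on the perturbed
# values, then mask the original logits with the perturbed ordering, which
# leaves exactly k entries per row.
perturbed = logits + torch.empty_like(logits).uniform_(0, 1e-6)
p_cutoff = perturbed.topk(int(k.max()), dim=-1).values.gather(-1, (k - 1).unsqueeze(-1))
print(logits.masked_fill(perturbed < p_cutoff, float("-inf")))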

Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025

Labels: ready (ONLY add when PR is ready to merge / full CI is needed), v1
Participants: 4