Conversation

@NickLucche (Collaborator) commented Mar 25, 2025

Enabling the topk optimization that was introduced in #15242.

Currently facing the very issue foreseen by @njhill here: #15242 (comment).

```
ERROR 03-25 18:27:23 [core.py:343]     random_sampled = self.topk_topp_sampler(
ERROR 03-25 18:27:23 [core.py:343]                      ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 03-25 18:27:23 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-25 18:27:23 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 03-25 18:27:23 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-25 18:27:23 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343]   File "/home/nick/vllm/vllm/v1/sample/ops/topk_topp_sampler.py", line 119, in forward_tpu
ERROR 03-25 18:27:23 [core.py:343]     topk_values, topk_indices = torch.topk(logits, k, dim=-1)
ERROR 03-25 18:27:23 [core.py:343]                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-25 18:27:23 [core.py:343] TypeError: topk(): argument 'k' (position 2) must be int, not Tensor
```
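The root cause: torch.topk only accepts a Python int for k, so the per-request k tensor built for a batch cannot be passed through directly. A minimal repro outside vLLM (values are illustrative):

```python
import torch

logits = torch.randn(2, 8)
k = torch.tensor([3, 5])        # per-request k values in a batch
torch.topk(logits, 3, dim=-1)   # fine: k is a Python int
torch.topk(logits, k, dim=-1)   # TypeError: 'k' must be int, not Tensor
```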

Dumping work for reference, will look into it asap.

Update:

For completeness, I've run microbenchmarks and the new impl is slower (but of course correct):

```
// before
Running 32 elapsed time: 0.0018310546875
Running 32 elapsed time: 0.0017833709716796875
// after
Running 32 elapsed time: 0.003275632858276367
Running 32 elapsed time: 0.003297090530395508
```

cc @hyeygit

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 25, 2025
@mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
@mergify mergify bot removed the needs-rebase label Mar 26, 2025
@NickLucche NickLucche marked this pull request as ready for review March 26, 2025 10:38
@NickLucche NickLucche marked this pull request as draft March 26, 2025 11:14
@NickLucche (Collaborator, Author)

Let's hold until main is fixed to reduce entropy

@mergify bot commented Mar 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the tpu (Related to Google TPUs) and needs-rebase labels Mar 27, 2025
@hyeygit (Contributor) commented Mar 27, 2025

Thank you @NickLucche for this PR and thank you @njhill for fixing the batch case for top-k! In #15242 I only tested with the microbenchmark and test_sampler.py (scalar case only), not realizing that k can be batched. Thank you for the catch and sorry for the miss.

One thing to note is that on TPU torch.topk still involves a full-vocab sort (see the XLA lowering). The reason torch.topk is so much faster on TPU is that it avoids the full-vocab torch.scatter (as used in apply_top_k_top_p), which is extremely slow on TPU.
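For readers less familiar with the two code paths, here is a rough, simplified sketch of the difference (illustrative only, not the actual vLLM implementation, which also handles top-p and per-request k):

```python
import torch

def top_k_mask_via_scatter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Sort the full vocab, mask everything past the k-th entry in sorted order,
    # then scatter the masked values back to their original token positions.
    # The full-vocab scatter is the step reported to be very slow on TPU.
    vals, idx = logits.sort(dim=-1, descending=True)
    cut = torch.arange(logits.shape[-1], device=logits.device) >= k
    vals = vals.masked_fill(cut, float("-inf"))
    return torch.empty_like(logits).scatter(-1, idx, vals)

def top_k_mask_via_topk(logits: torch.Tensor, k: int) -> torch.Tensor:
    # torch.topk avoids the scatter: keep logits >= the k-th largest value
    # per row (ties at the threshold are also kept).
    thresh = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < thresh, float("-inf"))
```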

@NickLucche (Collaborator, Author) commented Mar 31, 2025

The main blocker for this PR is topk recompilation.
I was under the impression that the XLA lowering would circumvent that need:

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch

B, V = 3, 64
device = xm.xla_device()

logits = torch.randn(B, V, device=device)

# Pre-compile with one value of k.
k = 3
top_k_mask = logits.topk(k, dim=1).values
top_k_mask = top_k_mask.cpu()

print(met.short_metrics_report())
met.clear_all()

# Run with other k values.
for k in [1, V]:
    top_k_mask = logits.topk(k, dim=1).values
    top_k_mask = top_k_mask.cpu()
    print(met.short_metrics_report())  # shows it compiles again for each new k
    met.clear_all()
```

Is there something I am missing? @hyeygit @yaochengji

@hyeygit (Contributor) commented Mar 31, 2025

> The main blocker for this PR is topk recompilation. I was under the impression the XLA lowering would circumvent that need

@NickLucche thank you for raising this. I think it's because the return shape of torch.topk is a function of k, so a change in k would cause XLA to recompile.

If this k-induced recompilation is not acceptable (sounds like it isn't), then let me implement a TPU-specific top-k using the same logic as in #15736. Will send out a PR shortly.
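For context, a minimal sketch of the threshold idea (illustrative only, not the actual implementation): sort once, read the k-th largest value per row as a cutoff, and mask in place, so no shape in the graph depends on k and per-request k values are supported.

```python
import torch

def apply_top_k_static_shape(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # logits: [B, V]; k: int64 tensor of per-request top-k values in [1, V].
    # Every intermediate shape depends only on B and V, never on k, so XLA
    # compiles this graph once even when k differs across requests or steps.
    sorted_logits = logits.sort(dim=-1, descending=True).values   # [B, V]
    kth_value = sorted_logits.gather(-1, (k - 1).unsqueeze(-1))   # [B, 1]
    return logits.masked_fill(logits < kth_value, float("-inf"))
```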

@NickLucche (Collaborator, Author)

> I think it's because the return shape of torch.topk is a function of k, so a change in k would cause XLA to recompile.

Yeah, the reason is quite clear to me; I just don't know why I believed this wouldn't trigger recompilation, lol.

> then let me implement a TPU-specific top-k using the same logic as in #15736.

That will do, thanks a lot!
Then we'll have to put this PR on hold until the above gets merged, or I'll cherry-pick your work here and add you as co-author.

@mergify bot commented Apr 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 2, 2025
@NickLucche (Collaborator, Author)

@hyeygit I took the liberty of adding your contribution from #15891 to this PR so we can test. Locally everything looks fine; let's wait for the CI tests. Nice work, thanks!

@hyeygit (Contributor) commented Apr 3, 2025

> So enabling topk here is costing us a bit of performance.

@NickLucche yes, I think I can corroborate this. I ran benchmark_serving.py with top-k and top-p enabled and observed a 6.8% drop in throughput (5.98 req/s -> 5.57 req/s). This was on a TPU v6e-1 VM running Llama 3.1-8B.

@alexm-redhat (Collaborator)

Just ran this PR with Llama 70B on 8 x v6e TPUs. It achieves 5 req/s instead of the previous 5.1 req/s, so the performance penalty I see is negligible.

@yaochengji (Collaborator) left a comment

LGTM! Thanks, Nick!

NickLucche and others added 3 commits April 11, 2025 15:01
Signed-off-by: NickLucche <nlucches@redhat.com>
Previously we found that using torch.topk resulted in a significant
speed-up on TPU. It turns out that's not a viable solution because the
return shape of torch.topk depends on k, which means an XLA recompilation
is triggered every time k changes.

Additionally, we realized that torch.scatter was the main bottleneck for
the original top-k impl on TPU. This PR circumvents both problems by using
a threshold-based approach to find the top-k set. The algorithm is nearly
identical to that of top-p; see vllm-project#15736 for more details.

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@alexm-redhat (Collaborator) left a comment

LGTM! Just some small comments

Signed-off-by: NickLucche <nlucches@redhat.com>
@hyeygit (Contributor) commented Apr 14, 2025

@NickLucche small request -- could you incorporate the top-k equivalence test from my PR? https://github.com/vllm-project/vllm/pull/15891/files#diff-09d15417fe42d494c51aaa9635ad51536751cc3c0659a7c4ce3b66bd6900eb1f

Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) April 17, 2025 16:28
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 17, 2025
@robertgshaw2-redhat robertgshaw2-redhat merged commit eb5819b into vllm-project:main Apr 17, 2025
58 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
@hgt312 commented Jul 29, 2025

Shouldn't the default top_k pad value be something other than zero? A pad of 0 triggers the error in probs_sort.gather(-1, top_k_count).
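For illustration only (the exact index derivation in the real code may differ): if the padded k feeds into a gather index, a pad of 0 can push that index out of torch.gather's valid range:

```python
import torch

probs_sort = torch.rand(2, 8)                           # [batch, vocab]
k = torch.tensor([3, 0])                                # hypothetical padded k; 0 = "no top-k" pad
top_k_count = (probs_sort.size(-1) - k).unsqueeze(-1)   # hypothetical index derivation
# For the padded row the index equals the vocab size (8), which is out of
# bounds for a dimension of size 8, so gather raises a RuntimeError.
cutoff = probs_sort.gather(-1, top_k_count)
```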


Labels

ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

7 participants