[V1][TPU] Enable Top K #15489
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add 🚀 […]
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 817c4f1 to a5bf849.
Let's hold until main is fixed, to reduce entropy.

This pull request has merge conflicts that must be resolved before it can be merged.
Thank you @NickLucche for this PR, and thank you @njhill for fixing the batch case for top-k! In #15242 I only tested with the microbenchmark and test_sampler.py (the scalar case only), not realizing that […]. One thing to note is that on TPU […]
Force-pushed from fac0e65 to 4759b89.
Force-pushed from 4759b89 to 0b022fb.
The main blocker for this PR is topk recompilation:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

B, V = 3, 64
device = xm.xla_device()
logits = torch.randn(B, V, device=device)

# Pre-compile once with a fixed k, then reset the metrics so the
# next report shows only newly triggered compilations.
k = 3
top_k_mask = logits.topk(k, dim=1).values
top_k_mask = top_k_mask.cpu()
print(met.short_metrics_report())
met.clear_all()

# Run with different k values.
for k in [1, V]:
    top_k_mask = logits.topk(k, dim=1).values
    top_k_mask = top_k_mask.cpu()
    print(met.short_metrics_report())  # shows it compiles again for each new k
    met.clear_all()
```

Is there something I am missing? @hyeygit @yaochengji
@NickLucche thank you for raising this. I think it's because the return shape of `torch.topk` depends on `k`, so an XLA recompilation is triggered every time `k` changes. If this k-induced recompilation is not acceptable (sounds like it isn't), then let me implement a TPU-specific top-k using the same logic as in #15736. Will send out a PR shortly.
Yeah, the reason is quite clear to me; I just don't know why I believed this wouldn't trigger recompilation lol.
That will do, thanks a lot!
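For readers following along, here is a minimal sketch of what such a threshold-based, shape-static top-k can look like. This illustrates the general technique, not the exact implementation that landed in #15736; the helper name and the clamping of `k` are this sketch's own choices.

```python
import torch

def top_k_static_shape(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mask logits outside each row's top-k set using only shape-static ops.

    k is a per-row LongTensor of shape (B,). Every intermediate tensor has
    shape (B, V) or (B, 1) regardless of the values in k, so changing k at
    runtime does not change any traced shapes.
    """
    B, V = logits.shape
    # Full descending sort: output shape (B, V) is independent of k.
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    # Per-row threshold: the value of the k-th largest logit.
    idx = (k - 1).clamp(min=0, max=V - 1).unsqueeze(-1)   # (B, 1)
    threshold = sorted_logits.gather(-1, idx)             # (B, 1)
    # Keep everything >= threshold (ties with the k-th value survive too).
    return logits.masked_fill(logits < threshold, float("-inf"))
```

Substituting this into the repro loop above should leave the `CompileTime` counter in the metrics report flat across different `k` values, since `topk`'s k-dependent output shape was the only shape that varied.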
This pull request has merge conflicts that must be resolved before it can be merged.
@NickLucche yes, I think I can corroborate this. I ran […]
Just ran this PR for Llama 70B with 8 x v6e TPUs. It achieves 5 reqs/sec instead of the previous 5.1 reqs/sec, so the performance penalty I see is negligible.
Force-pushed from 6dc18ce to 6a6579c.
LGTM! Thanks, Nick!
Signed-off-by: NickLucche <nlucches@redhat.com>
Previously we found that using torch.topk resulted in a significant speedup on TPU. It turns out that's not a viable solution, because the return shape of torch.topk depends on k, which means an XLA recompilation is triggered every time k changes. Additionally, we realized that torch.scatter was the main bottleneck for the original top-k impl on TPU. This PR circumvents both problems by using a threshold-based approach to find the top-k set. The algorithm is nearly identical to that of top-p; see vllm-project#15736 for more details. Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Force-pushed from 0530d80 to 71700b5.
LGTM! Just some small comments
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche small request -- could you incorporate the top-k equivalence test from my PR? https://github.com/vllm-project/vllm/pull/15891/files#diff-09d15417fe42d494c51aaa9635ad51536751cc3c0659a7c4ce3b66bd6900eb1f
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com>
Force-pushed from 98ac301 to ee0cd58.
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: Hyesoo Yang <hyeygit@gmail.com> Signed-off-by: Yang Wang <elainewy@meta.com>
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: Hyesoo Yang <hyeygit@gmail.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: Hyesoo Yang <hyeygit@gmail.com> Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Shouldn't the default top_k pad value be non-zero? 0 triggers the error in […]
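For context, a hypothetical illustration of that failure mode, assuming a threshold-style implementation like the sketch earlier in this thread (the tensor names here are made up): if rows with top-k disabled are padded with `k = 0`, the gather index `k - 1` becomes `-1`, which `torch.gather` rejects; padding with the vocabulary size `V` instead turns the threshold cut into a no-op for those rows.

```python
import torch

B, V = 2, 8
logits = torch.randn(B, V)
sorted_logits, _ = logits.sort(dim=-1, descending=True)

k = torch.tensor([[0], [3]])  # row 0: top-k "disabled", padded with 0
# sorted_logits.gather(-1, k - 1)  # RuntimeError: index -1 is out of bounds

# Pad disabled rows with V instead: their threshold becomes the row minimum,
# so no logit is masked and top-k is effectively a no-op for those rows.
k = torch.where(k == 0, torch.full_like(k, V), k)
threshold = sorted_logits.gather(-1, k - 1)
masked = logits.masked_fill(logits < threshold, float("-inf"))
```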
Enabling the topk optimization that was introduced in #15242.
Currently facing the very issue foreseen by @njhill here: #15242 (comment).
Dumping work for reference, will look into it asap.

Update: for completeness, I've run microbenchmarks and the new impl is slower (but of course correct): […]

cc @hyeygit