@hyeygit hyeygit commented Apr 1, 2025

Previously we found that using torch.topk resulted in a significant
speed-up on TPU. It turns out that's not a viable solution, because the
return shape of torch.topk depends on k, which means an XLA recompilation
is triggered every time k changes.

Additionally, we realized that torch.scatter was the main bottleneck in
the original top-k implementation on TPU. This PR circumvents both problems
by using a threshold-based approach to find the top-k set. The algorithm is
nearly identical to that of top-p; see #15736 for more details.
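The threshold-based idea can be sketched roughly as follows (a hypothetical `apply_top_k_threshold` helper, not the actual vLLM code): sort the logits once, read off the k-th largest value per row as a threshold, and mask everything below it. Every intermediate tensor has a shape that is independent of k, so XLA compiles the graph once, and no torch.scatter is involved.

```python
import torch

def apply_top_k_threshold(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the top-k set without torch.topk or torch.scatter.

    All tensor shapes here depend only on the input shape, never on k,
    so varying k per request does not trigger XLA recompilation.
    Hypothetical sketch of the threshold-based idea described above.
    """
    # Sort once; the output shape equals the input shape regardless of k.
    sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)
    # Per-row threshold: the k-th largest logit (k is a per-row int tensor).
    idx = (k - 1).clamp(min=0).unsqueeze(-1)        # shape [batch, 1]
    threshold = sorted_logits.gather(-1, idx)       # shape [batch, 1]
    # Mask everything strictly below the threshold. Note that exact ties at
    # the threshold are all kept, so slightly more than k tokens can survive.
    return logits.masked_fill(logits < threshold, float("-inf"))
```

For example, with a row of logits `[1, 3, 2, 0]` and k = 2, the threshold is the 2nd-largest logit (2.0), so only the logits 3.0 and 2.0 survive the mask.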

Benchmark

The sampling microbenchmark yields similar results to those in the top-p PR, with "Running 32 elapsed time" averaging ~5 ms (down from 500 ms pre-optimization).

An end-to-end serving benchmark with both top-k and top-p enabled shows that on TPU (v6e-1) running Llama3.1-8B, the TPU optimization here (along with the top-p PR) yields a 23x speed-up.

| Description | Throughput (req/s) |
| --- | --- |
| Baseline (no sampling) | 5.98 |
| Baseline w/ non-optimized sampling (`forward_native`) | 0.24 |
| TPU-optimized sampling | 5.57 |

@github-actions

github-actions bot commented Apr 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@hyeygit hyeygit marked this pull request as draft April 1, 2025 15:36
@mergify mergify bot added the ci/build, v1, and tpu (Related to Google TPUs) labels Apr 1, 2025

hyeygit commented Apr 1, 2025

cc @NickLucche fyi

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
@hyeygit hyeygit marked this pull request as ready for review April 3, 2025 15:38
@NickLucche
Collaborator

I think we can get this one in as a commit in #15489. We can probably close this.


hyeygit commented Apr 7, 2025

I think we can get this one in as a commit in #15489. We can probably close this.

@NickLucche Sounds good. Could you also apply ae7bc73 to your PR? Thanks!

@hyeygit hyeygit closed this Apr 7, 2025