[V1][TPU] TPU-optimized top-k implementation (2nd try) #15891
Closed
Previously we found that using torch.topk resulted in a significant
speedup on TPU. It turns out that's not a viable solution, because the
return shape of torch.topk depends on k, which means an XLA recompilation
is triggered every time k changes.
Additionally, we realized that torch.scatter was the main bottleneck for
the original top-k implementation on TPU. This PR circumvents both problems by using
a threshold-based approach to find the top-k set: the output shape stays fixed and no
scatter back into the original layout is needed. The algorithm is nearly
identical to that of top-p; see #15736 for more details. A minimal sketch of the idea is shown below.
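For illustration only, here is a rough sketch of what a threshold-based top-k can look like in PyTorch. The function name, signature, and tie-handling are illustrative assumptions, not the exact code in this PR:

```python
import torch


def apply_top_k_threshold(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the top-k set using a per-row threshold.

    logits: [num_seqs, vocab_size]
    k:      [num_seqs] integer tensor of per-sequence k values (1 <= k <= vocab_size)
    """
    # Sort once; the output shape does not depend on k, so XLA compiles this only once.
    sorted_logits = logits.sort(dim=-1, descending=True).values
    # The k-th largest value in each row serves as that row's cutoff.
    cutoff = sorted_logits.gather(-1, (k - 1).unsqueeze(-1).to(torch.int64))
    # Keep everything >= cutoff; no torch.scatter back into the original layout is needed.
    # Note: ties at the cutoff may keep slightly more than k entries.
    return logits.masked_fill(logits < cutoff, float("-inf"))


# Example: a batch of 2 sequences with different k values.
logits = torch.randn(2, 8)
k = torch.tensor([3, 5])
masked = apply_top_k_threshold(logits, k)
```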
Benchmark
The sampling microbenchmark yields results similar to those in the top-p PR, with "Running 32 elapsed time" averaging ~5 ms (down from 500 ms pre-optimization).
An end-to-end serving benchmark with both top-k and top-p enabled shows that on TPU (v6e-1) running Llama3.1-8B, the TPU optimization here (along with the top-p PR) yields a 23x speedup.