Previously we found that using torch.topk resulted in a significant
speedup on TPU. It turns out that's not a viable solution because the
return shape of torch.topk depends on k, which means an XLA recompilation
is triggered every time k changes.
Additionally, we realized that torch.scatter was the main bottleneck for
the original top-k implementation on TPU. This PR circumvents both problems
by using a threshold-based approach to find the top-k set: determine the
k-th largest logit per row and mask out everything below it, so all tensor
shapes stay static regardless of k. The algorithm is nearly identical to
that of top-p; see #15736 for more details.
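
For illustration, here is a minimal sketch of the threshold-based idea,
assuming logits of shape [batch, vocab] and a per-request integer tensor
of k values. The function name `apply_top_k_mask` and the exact masking
details are hypothetical, not the PR's actual code; ties at the threshold
may admit slightly more than k tokens.

```python
import torch

def apply_top_k_mask(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mask logits outside the top-k set via a
    per-row threshold instead of torch.topk / torch.scatter.

    logits: [batch, vocab] float tensor; k: [batch] long tensor.
    """
    # Sort once in descending order. The output shape is always
    # [batch, vocab] regardless of k, so XLA compiles this only once.
    sorted_logits, _ = logits.sort(dim=-1, descending=True)

    # The threshold is the k-th largest value in each row. Clamp k into
    # [1, vocab] so the gather index is always valid.
    vocab = logits.shape[-1]
    idx = (k.clamp(min=1, max=vocab) - 1).unsqueeze(-1)   # [batch, 1]
    threshold = sorted_logits.gather(-1, idx)             # [batch, 1]

    # Keep everything >= threshold; everything else becomes -inf, which
    # drops it from a subsequent softmax. No scatter needed.
    return torch.where(
        logits >= threshold,
        logits,
        torch.full_like(logits, float("-inf")),
    )
```

Because both the sort and the comparison operate on fixed-shape tensors,
varying k across requests changes only tensor *values*, not shapes, which
is what avoids the recompilation problem described above.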
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>