[TOPI] GPU sort IR refactor to enable sort by keys #7157
LGTM. Yes, thrust's radix sort is very good.
Can you share the benchmark numbers? I'm curious.
I definitely didn't promise to be faster than thrust with #7099, just faster than what we had in TIR :)
Ok here it is. I think TIR numbers for 5000 and 10000 agree with the ones you had in #7099
Yep, this matches the scaling I would expect with that algorithm. I'm kind of surprised that sequential scatter isn't worse than it is. I'm looking for ways to improve the mergesort, but I haven't found a good idea that we can implement in TIR yet; what I've seen requires a
Thanks @masahi @mbrookhart @Laurawly
* sort refactor initial import
* sort test working
* scatter 1d with positive indices working
* remove negatiev indices, using extern for now
* minor fix
* minor fix
* add sort by key test
* revert scatter change
* add document
* fix py format

Co-authored-by: masa <masa@pop-os.localdomain>
This adds support for sort by key in TOPI CUDA by refactoring the sort IR developed in #7099. With a different initialization, the exact same sort IR can be reused for sort by key. This works because the original IR already supports argsort, and argsort is just a special case of sort by key.
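To illustrate why argsort is a special case of sort by key, here is a minimal pure-Python sketch. It is not the TIR sort IR from this PR; the names `merge_sort_by_key` and `argsort` are hypothetical, and a simple stable bottom-up merge sort is assumed only for illustration. The point is that the sorting routine stays the same and only the initialization of the values buffer changes.

```python
def merge_sort_by_key(keys, values):
    """Stable bottom-up merge sort that moves `values` along with `keys`.

    Illustrative only: a CPU reference, not the GPU IR in this PR.
    """
    keys, values = list(keys), list(values)
    n = len(keys)
    width = 1
    while width < n:
        out_k, out_v = [None] * n, [None] * n
        # Merge adjacent sorted runs of length `width`.
        for start in range(0, n, 2 * width):
            mid = min(start + width, n)
            end = min(start + 2 * width, n)
            i, j, k = start, mid, start
            while i < mid or j < end:
                # Take from the left run on ties to keep the sort stable.
                take_left = j >= end or (i < mid and keys[i] <= keys[j])
                src = i if take_left else j
                out_k[k], out_v[k] = keys[src], values[src]
                if take_left:
                    i += 1
                else:
                    j += 1
                k += 1
        keys, values = out_k, out_v
        width *= 2
    return keys, values


def argsort(keys):
    # Same routine, different initialization: the values are the indices 0..n-1.
    _, indices = merge_sort_by_key(keys, range(len(keys)))
    return indices
```

Calling `merge_sort_by_key(keys, payload)` with an arbitrary payload gives sort by key, while `argsort(keys)` reuses the same routine with the payload initialized to the element positions.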
My motivation was to improve the default performance of scatter 1D using a sorting-based approach (see #7056) when thrust is not available. But it turns out that the sequential scatter kernel is much faster than doing the TIR sort, so I dropped that goal for now (this also suggests that the thrust radix sort I used in #7056 is insanely fast). I am sending this PR anyway in the hope that someone will find sort by key useful. My branch that does scatter 1D using the TIR sorting kernel, along with a benchmark script, is available at https://github.com/masahi/tvm/tree/sort_ir_refactor_with_scatter
I also removed `sort_nms_ir`, since it is not used anymore and I think it would be much faster to sort entire rows using `sort_ir` anyway.

Please review @mbrookhart @Laurawly @zhiics @tkonolige