[TOPI] Simplify GPU NMS IR and optimize a bit #7136
Conversation
UPDATE: The above issues have been fixed, and CI is also green.
(force-pushed from b70e26c to 622a876)
Looks good, I like how compact this makes it. I see the issue with the IOU rejection algorithm. It's a hard thing to speed up: parallelizing it leads to race conditions that change which boxes get selected. For instance, if we parallelize the outer loop, we can imagine a situation where boxes 0, 1, and 2 overlap: 0 and 1 overlap enough to reject box 1, and 1 and 2 overlap enough to reject box 2, but 0 and 2 don't overlap enough to reject box 2. In the serial case, 0 would reject 1, then 2 would only be compared to 0 and accepted; we would return boxes 0 and 2. In the parallel case we could have a race where box 2 is rejected by box 1 before box 1 is itself rejected: the thread comparing 0 and 1 then rejects 1, and we return only box 0.
I've been thinking about it tonight, and I don't think it's possible to parallelize the outer loop without introducing race conditions. We may, however, be able to parallelize the inner loop: if any of the boxes k in 0:j rejects box j, that takes box j out of future comparisons, but no other box is affected.
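To make the counterexample concrete, here is a minimal NumPy sketch, not the TOPI code: the iou and nms_serial helpers and the box coordinates are made up for illustration. It runs the serial triangle loop on the three-box scenario above:

```python
import numpy as np

def iou(a, b):
    # Boxes are [x1, y1, x2, y2].
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def nms_serial(boxes, iou_threshold):
    # Boxes are assumed pre-sorted by descending score.
    keep = np.ones(len(boxes), dtype=bool)
    for j in range(len(boxes)):    # outer loop: must stay serial
        for k in range(j):         # inner loop: writes only keep[j], and
                                   # keep[k] (k < j) is already final here,
                                   # so this loop is safe to parallelize
            if keep[k] and iou(boxes[k], boxes[j]) > iou_threshold:
                keep[j] = False
                break
    return keep

# Boxes 0/1 and 1/2 overlap heavily; 0/2 overlap only slightly.
boxes = np.array([[0.0, 0.0, 10.0, 10.0],
                  [4.0, 0.0, 14.0, 10.0],
                  [8.0, 0.0, 18.0, 10.0]])
print(nms_serial(boxes, 0.3))  # [ True False  True]: 0 rejects 1, 2 survives
```

With a naive parallel outer loop, the thread comparing boxes 1 and 2 could reject box 2 before box 1 is itself rejected, leaving only box 0.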
LGTM. Let's move optimizing the triangle loop to a second PR. Thanks!
Great job 👍
Thanks @masahi @mbrookhart
* remove get_valid_counts from pytorch nms
* fix pytorch nms for negative score
* merge reset by -1
* move max_out_size handling to triangle loop
* update torch nms test
* fuse the last two kernels
* parallelize the first kernel
* merge first and last kernel
* remove unnecessary cases
* fix typo
* revert pytorch frontend change
* fuse rearrange step with triangle loop
* fix max_output_size handling
* check if already suppressed
* fix topi vision test by wrapping tir const around int argument
* fix for num anchors = 0 case
* fix missing zero init of num valid boxes when the input is empty
* add some comments and missing doc
* typo fix
* add a guard against zero dim grid / thread block inside ir_builder
* typo fix
* trigger CI
I spent some time studying the GPU NMS implementation in detail and found a number of possible improvements. I simplified the IR quite a bit: the main NMS routine now consists of only two kernel calls (initialization and the triangle loop). I also removed unnecessary computation and memory movement. I believe what I have now is a much simpler and faster NMS.
In particular, on the PyTorch Mask RCNN workload (4900 boxes), this PR cuts more than 1000 microseconds. This is significant because the PyTorch NMS kernel itself takes only 200 microseconds. NMS from GluonCV SSD is also 2x faster (24 ms -> 12 ms, see below).
Our NMS is still super slow for PyTorch models due to the serial triangle loop (fused_vision_non_max_suppression_kernel2 below): PyTorch NMS doesn't have topk or score_threshold, so we need to take all boxes into account. Optimizing and parallelizing the triangle loop is important future work that I plan on doing.
I must admit that some of the optimizations I did in the triangle loop are only possible because we are doing the loop sequentially. If we find a way to parallelize that loop, we will need to remove those changes. But until then, I think this is the better approach.
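For reference, here is a minimal scalar sketch of what the fused triangle loop computes, reusing the hypothetical iou helper from the sketch above; the function name and signature are made up, and this is not the actual TIR. It shows the two serial-only tricks mentioned: comparing only against surviving boxes, and enforcing max_output_size inside the loop:

```python
def nms_triangle_loop(boxes, iou_threshold, max_output_size=-1):
    # boxes are pre-sorted by descending score; returns kept indices.
    selected = []
    for j in range(len(boxes)):
        # Compare box j only against boxes that survived so far.
        # Skipping suppressed boxes is valid only because the loop runs
        # serially: the fate of every earlier box is already decided.
        if any(iou(boxes[k], boxes[j]) > iou_threshold for k in selected):
            continue
        selected.append(j)
        # max_output_size is handled inside the loop, so we can stop
        # early instead of post-filtering in a separate kernel.
        if max_output_size >= 0 and len(selected) == max_output_size:
            break
    return selected
```

If the outer loop is ever parallelized, the comparison would have to go back to considering all earlier boxes, which is the trade-off described above.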
cc @Laurawly @kevinthesun @mbrookhart @zhiics @vinx13
nvprof output from running PyTorch MaskRCNN
Before:
After:
GluonCV SSD
Before:
After: