
[TOPI] Simplify GPU NMS IR and optimize a bit #7136

Merged: 22 commits into apache:main on Dec 21, 2020

Conversation

@masahi (Member) commented Dec 20, 2020

I spent some time studying the GPU NMS implementation in detail and found a number of possible improvements. I simplified the IR quite a bit; the main NMS routine now consists of only two kernel calls (initialization and the triangle loop). I also removed unnecessary computation and memory movement. I believe what I have now is a much simpler and faster NMS.

In particular, on the PyTorch Mask RCNN workload (4900 boxes), this PR cuts more than 1000 microseconds. This is significant because the PyTorch NMS kernel itself takes only 200 microseconds. NMS from the GluonCV SSD model is also 2x faster now (24 ms -> 12 ms, see below).

Our NMS is still super slow for PyTorch models due to the serial triangle loop (fused_vision_non_max_suppression_kernel2 below): PyTorch NMS doesn't have topk or score_threshold, so we need to take all boxes into account. Optimizing and parallelizing the triangle loop is important future work that I plan to do.

I must admit that some of the optimizations I made in the triangle loop are only possible because we are running the loop sequentially. If we find a way to parallelize that loop, we will need to revert those changes. But until then, I think this is the better approach.
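
For context, the "triangle loop" above is the O(n^2) suppression pass over score-sorted boxes: box j is only compared against the higher-scoring boxes 0..j-1 that have not already been suppressed. Below is a minimal NumPy sketch of that logic; the names (`nms_triangle_loop`, `iou`) are illustrative only, and this is plain Python for exposition, not the actual TOPI IR (which is generated with the TIR ir_builder).

```python
import numpy as np


def iou(a, b):
    # Intersection-over-union of two boxes in (x1, y1, x2, y2) layout.
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, right - left) * max(0.0, bottom - top)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)


def nms_triangle_loop(boxes, iou_threshold, max_output_size=-1):
    """Serial NMS over boxes already sorted by descending score.

    Box j is compared only against earlier (higher-scoring) boxes,
    hence the "triangle" shape of the pairwise work.
    """
    n = len(boxes)
    suppressed = np.zeros(n, dtype=bool)
    keep = []
    for j in range(n):
        for k in range(j):
            # Only surviving higher-scoring boxes can suppress box j.
            if not suppressed[k] and iou(boxes[k], boxes[j]) >= iou_threshold:
                suppressed[j] = True
                break
        if not suppressed[j]:
            keep.append(j)
            # max_output_size check; this PR moves that handling into the triangle loop.
            if 0 <= max_output_size <= len(keep):
                break
    return keep


# Tiny example: box 1 overlaps box 0 heavily, box 2 is disjoint.
boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype="float32")
print(nms_triangle_loop(boxes, 0.5))  # -> [0, 2]
```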

cc @Laurawly @kevinthesun @mbrookhart @zhiics @vinx13

nvprof output from running PyTorch MaskRCNN

Before:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:   99.95%  2.43544s         1  2.43544s  2.43544s  2.43544s  fused_vision_non_max_suppression_kernel2
                    0.02%  443.24us         1  443.24us  443.24us  443.24us  fused_vision_non_max_suppression_kernel6                                                        
                    0.02%  420.55us         1  420.55us  420.55us  420.55us  fused_vision_non_max_suppression_kernel1                                                                            
                    0.01%  253.40us         1  253.40us  253.40us  253.40us  fused_vision_non_max_suppression_kernel5                                    
                    0.00%  23.298us         1  23.298us  23.298us  23.298us  void cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>(int const *, cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>*, cub::DeviceRadixSortPolicy<float, int, int>::Policy700 const *, cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>**, bool=1, int, int)

After:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:  100.00%  2.17494s         1  2.17494s  2.17494s  2.17494s  fused_vision_non_max_suppression_kernel2
                    0.00%  22.401us         1  22.401us  22.401us  22.401us  void cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>(int const *, cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>*, cub::DeviceRadixSortPolicy<float, int, int>::Policy700 const *, cub::DeviceRadixSortSingleTileKernel<cub::DeviceRadixSortPolicy<float, int, int>::Policy700, bool=1, float, int, int>**, bool=1, int, int)
                    0.00%  14.976us        10  1.4970us     512ns  6.7840us  [CUDA memcpy HtoD]
                    0.00%  5.4720us         1  5.4720us  5.4720us  5.4720us  fused_vision_non_max_suppression_kernel1

GluonCV SSD
Before:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:   31.27%  24.388ms         1  24.388ms  24.388ms  24.388ms  fused_vision_non_max_suppression_kernel2                                                        
                   18.44%  14.378ms       206  69.795us     544ns  970.55us  [CUDA memcpy HtoD]                                                                              
                    4.59%  3.5769ms         1  3.5769ms  3.5769ms  3.5769ms  fused_vision_non_max_suppression_kernel4
                    3.82%  2.9766ms         1  2.9766ms  2.9766ms  2.9766ms  fused_vision_get_valid_counts_kernel1
                    3.64%  2.8394ms         1  2.8394ms  2.8394ms  2.8394ms  fused_nn_conv2d_add_4_kernel0

After:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:   24.25%  14.487ms       206  70.325us     544ns  987.28us  [CUDA memcpy HtoD]                                                                              
                   20.90%  12.489ms         1  12.489ms  12.489ms  12.489ms  fused_vision_non_max_suppression_kernel2                                                        
                    4.77%  2.8497ms         1  2.8497ms  2.8497ms  2.8497ms  fused_nn_conv2d_add_4_kernel0                                               
                    4.29%  2.5619ms         1  2.5619ms  2.5619ms  2.5619ms  fused_nn_conv2d_add_11_kernel0
                    3.09%  1.8472ms         1  1.8472ms  1.8472ms  1.8472ms  fused_nn_conv2d_add_nn_relu_3_kernel0

@masahi changed the title from "[TOPI, Torch] Simplify GPU NMS and fix PyTorch NMS conversion with negative scores" to "[TOPI, Torch] Simplify GPU NMS and optimize a bit" on Dec 20, 2020
@masahi changed the title from "[TOPI, Torch] Simplify GPU NMS and optimize a bit" to "[TOPI] Simplify GPU NMS IR and optimize a bit" on Dec 20, 2020
@masahi marked this pull request as draft December 20, 2020 22:53
@masahi (Member, Author) commented Dec 20, 2020

There is an issue when running one of the NMS tests in test_any.py, which is not run on CI due to the flaky sort. Converting to draft until I resolve this problem.

UPDATE: The above issue has been fixed, and CI is green.

@masahi marked this pull request as ready for review December 21, 2020 02:07
@mbrookhart (Contributor) left a comment

Looks good; I like how compact this makes it. I see the issue with the IOU rejection algorithm. It's a hard thing to speed up: parallelizing it leads to race conditions that change which boxes get selected. For instance, if we parallelize the outer loop, we can imagine a situation where boxes 0, 1, and 2 overlap: 0 and 1 overlap enough to reject box 1, and 1 and 2 overlap enough to reject box 2, but 0 and 2 don't overlap enough to reject box 2. In the serial case, 0 would reject 1, then 2 would only be compared against 0 and accepted, so we would return boxes 0 and 2. In the parallel case we could have a race where box 2 is rejected by box 1 before box 1 is itself rejected; the thread comparing 0 and 1 then rejects 1, and we return only box 0.

I've been thinking about it tonight, and I don't think it's possible to parallelize the outer loop without introducing race conditions. We may, however, be able to parallelize the inner loop: if any of the boxes k in 0:j rejects box j, that takes box j out of future comparisons, but no other box is affected.
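
To make that concrete, here is a rough sketch of the inner-loop formulation (plain Python with illustrative names, reusing the `iou` helper from the sketch above; not the actual TOPI code). The outer loop over j stays serial, but by the time we reach box j every `suppressed[k]` with k < j is final, so the inner reduction over k could be split across threads and combined with a logical OR without changing which boxes survive:

```python
def nms_inner_parallel(boxes, iou_threshold):
    """Same result as the serial triangle loop (without max_output_size).

    The outer loop stays sequential; the inner any() over k < j only
    reads already-final suppressed[k] flags, so it is a candidate for
    a parallel OR-reduction on the GPU.
    """
    n = len(boxes)
    suppressed = [False] * n
    for j in range(n):
        # Each k < j can be checked independently of the others.
        suppressed[j] = any(
            not suppressed[k] and iou(boxes[k], boxes[j]) >= iou_threshold
            for k in range(j)
        )
    return [j for j in range(n) if not suppressed[j]]
```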

Inline review comments on python/tvm/topi/cuda/nms.py (3 threads, resolved)
@mbrookhart (Contributor) left a comment

LGTM. Let's move optimizing the triangle loop to a second PR. Thanks!

@Laurawly (Contributor) left a comment

Great job 👍

@Laurawly merged commit 82942fb into apache:main Dec 21, 2020
@Laurawly (Contributor) commented

Thanks @masahi @mbrookhart

masahi added a commit to masahi/tvm that referenced this pull request Dec 24, 2020
* remove get_valid_counts from pytorch nms

* fix pytorch nms for negative score

* merge reset by -1

* move max_out_size handling to triangle loop

* update torch nms test

* fuse the last two kernels

* parallelize the first kernel

* merge first and last kernel

* remove unnecessary cases

* fix typo

* revert pytorch frontend change

* fuse rearrange step with triangle loop

* fix max_output_size handling

* check if already suppressed

* fix topi vision test by wrapping tir const around int argument

* fix for num anchors = 0 case

* fix missing zero init of num valid boxes when the input is empty

* add some comments and missing doc

* typo fix

* add a guard against zero dim grid / thread block inside ir_builder

* typo fix

* trigger CI
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021