[ONNX] NMS in ONNX #6839

Merged
merged 10 commits into apache:main from mbrookhart/onnx_nms on Dec 14, 2020

Conversation

mbrookhart
Contributor

This PR adds ONNX support for NonMaxSuppression.

The ONNX API is a little odd; it roughly follows the TF combined_non_max_suppression op. To implement it in TVM, I decided to go with relay while loops instead of writing a new TOPI kernel.
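
For anyone unfamiliar with the pattern, here is a minimal, hypothetical sketch of the relay while_loop helper this approach builds on (the loop variables and body below are illustrative placeholders, not the actual converter code):

```python
import tvm
from tvm import relay
from tvm.relay.loops import while_loop

# Illustrative only: a scalar loop that sums 0..9, standing in for the
# per-box / per-class iteration the ONNX NMS converter performs.
max_count = relay.const(10, "int32")
i = relay.var("i", shape=(), dtype="int32")
acc = relay.var("acc", shape=(), dtype="int32")

def _cond(i, acc):
    # Keep iterating while i < max_count.
    return relay.less(i, max_count)

def _body(i, acc):
    # Return the updated loop variables for the next iteration.
    return [i + relay.const(1, "int32"), acc + i]

loop = while_loop(_cond, [i, acc], _body)
# Calling the loop with initial values yields a tuple of the final loop-variable
# values; graphs with control flow like this are executed with the relay VM.
result = relay.TupleGetItem(loop(relay.const(0, "int32"), relay.const(0, "int32")), 1)
mod = tvm.IRModule.from_expr(relay.Function([], result))
```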

Getting it to work on CUDA was a pain. First, I needed to change a couple of input values from attributes to parameters; then I found the CUDA kernels were out of spec and skipping tests. Implementing those tests and passing them required refactoring some of the kernels to bring them in spec, and then removing CUDA threading in several places so the kernels return correct results.

I labeled this WIP because I'd like to spend more time trying to figure out how to speed up the CUDA kernels, but I'd love ideas from anyone who's interested in looking at it.

@jroesch @jwfromm @zhiics @tkonolige @csullivan

Because you've touched the CUDA NMS file, could you take a look, @Laurawly @yongwww @kevinthesun?

Thanks,
Matthew

output[i, j] = -1
valid_box_count[i, 0] = valid_idx[0]

return ib.get()
Contributor

Could you show the performance benchmark on some popular OD model workloads with the modified nms.py?

Contributor Author

I've been attempting to get SSD-MobileNet from the ONNX model zoo working, but I'm hitting other bugs. I'll post some perf metrics as soon as I can get that working.

Contributor Author

I fixed a number of bugs in #6906 that allow me to run SSD-RN50 on CPU, but with those fixes, I'm hitting the issue of CUDA topk not supporting dynamic shapes, and the output of NMS is intrinsically dynamic. I'm back to trying to solve that problem before I can give you perf numbers.

Contributor Author

I'm still struggling to get full object detection models working on GPU due to the issues with dynamic topk in cuda, but while I work on that, I did some profiling of the unit test. With batch size = 2, num_anchors = 5, and num_classes = 2, I get the following performance on my 1070Ti:

Total Execution time: 900 microseconds
NMS time: 66 microseconds
get_valid_counts time: 15 microseconds.

It looks like the performance issues in the ONNX conversion are all about the slicing/expand_dims and loop-bounds checking; here are the most expensive ops (times in microseconds):

0                      fused_dyn_strided_slice_1        89.28450                                     
1                      fused_add_2                      87.05910                                     
2                      fused_less_min                   83.04660                                     
3                    fused_vision_non_max_suppression   66.43930                                     
4   fused_expand_dims_cast_concatenate_expand_dims...   53.65090                                     
5                      fused_reshape_concatenate        46.64780                                     
6                      fused_dyn_strided_slice          46.59800                                     
7                      fused_concatenate_2              33.06690                                     
8                      fused_dyn_broadcast_to_1         29.49310                                     
9                      fused_squeeze_concatenate        26.83250                                     
10                 fused_cast_expand_dims_concatenate   24.00320                                     
11  fused_concatenate_cast_like_less_cast_like_add...   23.53060                                     
12                     fused_dyn_strided_slice_2        23.39070                                     
13  fused_concatenate_cast_like_less_cast_like_add...   22.79010                                     
14                     fused_shape_of_1                 21.37450                                     
15                     fused_shape_of_2                 21.02070                                     
16                     fused_dyn_broadcast_to           18.43470                                     
17                     fused_vision_get_valid_counts    14.89870
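
For reference, this is roughly how I collect per-op timings like the ones above with the debug graph runtime (a sketch: the frontend call and shapes are placeholders, and the debugger module path has moved between TVM versions):

```python
import tvm
from tvm import relay
from tvm.contrib.debugger import debug_runtime  # renamed debug_executor in later TVM

# Placeholders: any built relay module will do; here it comes from the ONNX frontend.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(mod, target="cuda", params=params)

ctx = tvm.gpu(0)
m = debug_runtime.create(graph, lib, ctx, dump_root="/tmp/tvmdbg")
m.set_input(**params)
m.run()  # prints the per-op time breakdown and dumps a trace under dump_root
```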

cc @jwfromm @tkonolige

Contributor

@Laurawly Laurawly Dec 7, 2020

Could you show some performance numbers between the original cuda nms ir and your implementation?

Contributor Author

I have not been able to pass tests with the kernels currently in master, so I can't compare apples to apples.

import sys
import pytest

pytest.main(sys.argv)
Contributor

What's the thought behind this change?

Contributor Author

:/ I don't think I did it; I wonder if it's a rule in autoformatting.

Member

I think I did it. This means that pytest will actually take command-line flags when the file is run as main, including -k, -s, etc.

@jwfromm
Contributor

jwfromm commented Nov 12, 2020

Overall LGTM. I think the documentation on the NMS onnx conversion could be more detailed since it's a little difficult to understand right now.

@mbrookhart mbrookhart changed the title from [WIP][ONNX] NMS in ONNX to [ONNX] NMS in ONNX on Nov 18, 2020
(i * num_anchors + j) * elem_length + k
]
out_indices[i * num_anchors + valid_count[i]] = j
valid_count[i] += 1
Contributor

There's a data race on valid_count between line 129 and line 133, between the load and the write.

Contributor

Could you use atomic add here?

Contributor Author

atomic_add doesn't work with nvptx. That's a headache...
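
For context, this is roughly the pattern being suggested, as it appears in TVM's CUDA TOPI kernels around this time (a sketch from memory; the intrinsic only lowers on the cuda target, which is exactly the nvptx problem above):

```python
import tvm

def atomic_add(x, y):
    # Wraps the tir.atomic_add intrinsic: x is a pointer (address_of a buffer
    # element) and y the value to add; it returns the old value.
    return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)

# Inside an ir_builder kernel, a per-batch counter would be bumped like:
#   pos = atomic_add(
#       tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]),
#       tvm.tir.const(1, "int32"),
#   )
```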

Contributor Author

I do not have a race condition here because I'm not parallelizing over that axis. The issue with the current kernel is that it cannot implement the ONNX API because it cannot properly sort the output indices, and ONNX needs those outputs in the correct order to do some post-processing.

Contributor

@Laurawly Laurawly Dec 8, 2020

I see. So we lose parallelism in num_anchors (which could be quite large) compared with the original implementation. Could you elaborate on why we are not getting correct output indices in the current implementation?

Contributor Author

I don't see the sort?

It's faster, but it's not correct; that PR also disabled the relevant tests I need to ensure ONNX works properly.

Contributor Author

I'll go back and look at what you had.

Contributor Author

Do you happen to have a script to reproduce the benchmarks you had in that PR? I assume it was a TF import?

Contributor

It was actually an MXNet model, a customized detection model, and num_anchors is really large. You can run this tutorial for a benchmark reference: https://tvm.apache.org/docs/tutorials/frontend/deploy_ssd_gluoncv.html

Contributor

@Laurawly Laurawly Dec 9, 2020

Actually, I mentioned in PR #5339 that the model is ssd_resnet50_v1 with (1, 3, 512, 512) input shape, with Thrust turned on in the build. You can use the TVM profiling tool to get the timing for each op.

ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "virtual_thread", nthread_bx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
Contributor

Could you elaborate on the changes to this file based on performance results?

@mbrookhart
Contributor Author

#7005 re-implemented some of the features in this PR; I'll rebase and try to reconcile.

@mbrookhart
Contributor Author

@Laurawly @kevinthesun I have rebased, but I was unable to get it passing tests with Yao's changes. I'm going back through the kernels one by one to see if I can get the faster versions to pass tests before attempting the ONNX integration.

out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0

with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]):
    out[idx] = -1
Contributor Author

This implementation of rearrange_indices_out_ir returns an undersized tensor in some cases. I think the threading isn't quite right, but I haven't been able to fix it.

@mbrookhart
Contributor Author

@Laurawly I took the MXNet example you provided and ran it with the debug runtime. It required a little bit of editing; APIs have changed slightly since that tutorial was written. Anyway, this is what I get on my 1070 Ti with Thrust enabled.

main:

Ops                                 Time(us)    Time(%)  Shape
---                                 --------    -------  -----
fused_vision_non_max_suppression    139329.0    74.66    (1, 122640, 6)
fused_vision_get_valid_counts       124.255     0.067    (1, 122640, 6)

this PR:

fused_vision_get_valid_counts       46138.3     50.891   (1, 122640, 6)
fused_vision_non_max_suppression    12319.8     13.589   (1, 122640, 6)

The get_valid_counts function slowed down, but I'm actually seeing the total runtime of these ops decrease from 139.3 ms to 58.5 ms.

My modifications to the example can be found here: https://gist.github.com/mbrookhart/df25427cbbfb3c73ed16be72c8525610

@Laurawly
Contributor

Laurawly commented Dec 10, 2020

> (quoting @mbrookhart's benchmark comment above)

The time measurement is not as fast as before because of PR #7005. If you revert that one, you should get fairly good performance without any correctness problems. After my improvement, NMS was no longer a bottleneck of the SSD model, while in your measurement it seems it still is. So my point is that your changes to PR #7005 here:

if return_indices:
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = batch_size // max_threads + 1
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        i = bx * max_threads + tx
        with ib.if_scope(i < batch_size):
            with ib.for_range(0, valid_count[i]) as j:
                idx = box_indices[i * num_anchors + j]
                with ib.if_scope(idx >= 0):
                    box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]

and here:
out_shape = box_indices.shape
valid_box_count_shape = [box_indices.shape[0], 1]
valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
return te.extern(
    [out_shape, valid_box_count_shape],
    [box_indices],
    lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
    dtype="int32",
    out_buffers=[output, valid_box_count],
    name="rearrange_indices_out_gpu",
    tag="rearrange_indices_out_gpu",
should be faster and thus account for the performance improvement.

@mbrookhart
Contributor Author

So, we seem to be at a bit of an impasse. The implementation that is in master is not 100% correct, and it was even less correct before #7005. This PR is the best combination of correctness and performance we have at the moment. I would like to get it merged to have a correct implementation in place, but I would also like to find a way to speed up get_valid_counts; what I have here is not great.

I'm happy to spend the next few days trying to improve parallelism in get_valid_counts, but given that this is a performance improvement over main and it's more correct, I think we shouldn't block merging over the current performance; we could always do a second PR to improve performance once we have the unit tests in place.

@mbrookhart
Contributor Author

@Laurawly I just rewrote get_valid_counts in the way you suggested. I still need to better parallelize the sum/conditional scan operation, but this takes it to:

Ops                                 Time(us)    Time(%)  Shape
---                                 --------    -------  -----
fused_vision_non_max_suppression    12105.9     25.723   (1, 122640, 6)
fused_vision_get_valid_counts       3517.22     7.474    (1, 122640, 6)

I'm going to take another look at NMS before I try to parallelize the sum.
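
For anyone following along, the pattern we're discussing is essentially this (a NumPy reference of the idea, not the TIR kernel): flag boxes above the score threshold, exclusive-scan the flags to get output positions, then compact. The scan is the part that still runs serially here.

```python
import numpy as np

def get_valid_counts_ref(boxes, scores, score_threshold):
    """NumPy reference of the flag -> exclusive scan -> compact pattern."""
    flags = (scores > score_threshold).astype(np.int64)  # 1 if the box is kept
    positions = np.cumsum(flags) - flags                 # exclusive prefix sum
    out = np.full_like(boxes, -1.0)                      # pad rejected slots with -1
    out[positions[flags == 1]] = boxes[flags == 1]       # compact kept boxes to the front
    return int(flags.sum()), out
```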

@Laurawly
Contributor

Laurawly commented Dec 10, 2020

> So, we seem to be at a bit of an impasse. The implementation that is in master is not 100% correct, and it was even less correct before #7005. This PR is the best combination of correctness and performance we have at the moment. I would like to get it merged to have a correct implementation in place, but I would also like to find a way to speed up get_valid_counts; what I have here is not great.
>
> I'm happy to spend the next few days trying to improve parallelism in get_valid_counts, but given that this is a performance improvement over main and it's more correct, I think we shouldn't block merging over the current performance; we could always do a second PR to improve performance once we have the unit tests in place.

We tested the correctness of MXNet SSD detection models before PR #7005, and they were within the correctness margin for customers to ship. I do suggest a block-by-block perf comparison of your implementation and the one in main, and an overall performance comparison with the implementation before PR #7005's change to nms.py.

@mbrookhart
Contributor Author

> We tested the correctness of MXNet SSD detection models before PR #7005, and they were within the correctness margin for customers to ship.

This indicates that the kernel was correct for certain inputs, but not that it was correct overall.

@Laurawly
Contributor

> I'm happy to spend the next few days trying to improve parallelism in get_valid_counts, but given that this is a performance improvement over main and it's more correct, I think we shouldn't block merging over the current performance; we could always do a second PR to improve performance once we have the unit tests in place.

Why don't you put the new implementation in a separate file, say nms_onnx.py, or a separate function? Then we can see about merging it back once enough tests have passed to check the flakiness of the kernel and once it has better parallelism than using only batch_size threads.

@mbrookhart
Contributor Author

I really don't want the code fragmentation, especially when it's currently ~9x faster than main. I will keep pounding on speeding it up.

@mbrookhart
Contributor Author

I found a combination of serial and threaded portions of NMS that makes it fast while still passing the tests:

fused_vision_get_valid_counts       3674.62     10.87    (1, 122640, 6)
fused_vision_non_max_suppression    402.791     1.191    (1, 122640, 6)

I'll go back to parallelizing the sum/scan on get_valid_counts tomorrow, but at this point, this is ~30x faster than main.

@mbrookhart
Contributor Author

@Laurawly Commit fd5ce64, just before #7005, gets me this :/ I don't think the perf problem came from #7005:

fused_vision_non_max_suppression    142188.0    76.13    (1, 122640, 6)
fused_vision_get_valid_counts       76.874      0.041    (1, 122640, 6)

Contributor

@Laurawly Laurawly left a comment

Could you show a perf comparison of this CUDA NMS implementation with this version of nms.py: https://github.com/apache/tvm/blob/fdfc7eb8876278cbfd31f6d1d82ca75829a7aac4/python/tvm/topi/cuda/nms.py ? Trying to understand where the speedups are coming from. I'll also run some benchmarks on my end when I get some bandwidth.

and target.kind.name == "cuda"
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
):
if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
Contributor

@Laurawly Laurawly Dec 11, 2020

OpenCL doesn't have the Thrust lib; we need to assert that the target is CUDA here.

Contributor Author

Good catch, thanks!
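
For reference, the guarded check we're talking about amounts to something like this (a sketch, assuming the usual topi/cuda pattern):

```python
import tvm

def can_use_thrust_sort(target):
    # Only take the Thrust path when the target is CUDA (OpenCL has no Thrust)
    # and the packed function was actually compiled in (USE_THRUST=ON).
    return (
        target.kind.name == "cuda"
        and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) is not None
    )
```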

ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "virtual_thread", nthread_bx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
Contributor

Doesn't seem to make the sort faster or more stable.

Contributor Author

It makes it support dynamic input shapes.

Contributor Author

I'm working on a fuller rewrite of the sort kernel for stability in another branch. I understand why this isn't fast or stable, and I can do better for small inputs, but I need to rework things more to support full-sized inputs.

@Laurawly
Contributor

Laurawly commented Dec 11, 2020

> @Laurawly fdfc7eb matches the other tests I've done on recent versions of the main branch: 131 ms total time, 75% of runtime.

This benchmark has a large gap compared with the one we tested before. We need to verify it on our end first. CC @zhiics @yidawang

@jwfromm
Contributor

jwfromm commented Dec 11, 2020

@Laurawly I understand your concern about performance but I think @mbrookhart has provided pretty ample evidence that this code is better than what we have in master in terms of both correctness and performance. I think we should merge this and investigate improving performance even further in a follow-up PR. Since this PR itself causes no regressions, there's no need to make it even bigger with tons of additional optimization.

@Laurawly
Contributor

> @Laurawly I understand your concern about performance but I think @mbrookhart has provided pretty ample evidence that this code is better than what we have in master in terms of both correctness and performance. I think we should merge this and investigate improving performance even further in a follow-up PR. Since this PR itself causes no regressions, there's no need to make it even bigger with tons of additional optimization.

@zhiics is currently testing it. If it works for him, I’ll approve it.

@zhiics
Member

zhiics commented Dec 12, 2020

Sorry for the late response. I will update on Monday.

Contributor

@Laurawly Laurawly left a comment

Just verified with @zhiics, it works for us! Thanks for the contribution 😄

@mbrookhart
Contributor Author

Awesome, thank you for the help @Laurawly. The idea for speeding up get valid counts was especially helpful. I have an idea for parallelizing the last serial function there, but it follows a similar path to #7099, where I'm hitting some odd lowering issues. As soon as I can work those out, I'll submit a PR for parallel scan to further optimize get_valid_counts.
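
(For anyone curious, the parallel scan mentioned here is the standard work-efficient Blelloch exclusive scan. A NumPy illustration of the idea follows; each level's updates are independent, which is what makes every level a single parallel step on the GPU.)

```python
import numpy as np

def exclusive_scan_blelloch(flags):
    """Work-efficient exclusive prefix sum; matches np.cumsum(flags) - flags."""
    n = 1 << (len(flags) - 1).bit_length()  # pad length up to a power of two
    a = np.zeros(n, dtype=np.int64)
    a[: len(flags)] = flags
    # Up-sweep (reduction) phase.
    step = 1
    while step < n:
        a[2 * step - 1 :: 2 * step] += a[step - 1 :: 2 * step]
        step *= 2
    # Down-sweep phase.
    a[-1] = 0
    step = n // 2
    while step >= 1:
        left = a[step - 1 :: 2 * step].copy()
        a[step - 1 :: 2 * step] = a[2 * step - 1 :: 2 * step]
        a[2 * step - 1 :: 2 * step] += left
        step //= 2
    return a[: len(flags)]

# Example: exclusive_scan_blelloch(np.array([1, 0, 1, 1])) -> array([0, 1, 1, 2])
```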

@zhiics zhiics merged commit 054466b into apache:main Dec 14, 2020
@zhiics
Member

zhiics commented Dec 14, 2020

Thanks everyone!

@mbrookhart mbrookhart deleted the mbrookhart/onnx_nms branch December 14, 2020 18:54
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
* NMS partially working on CPU, fails on GPU

* support dynamic iou_threshold

* WIP NMS with while loops

* working nms with dynamic shapes

* add a test with dynamic score_threshold and pass it

* Fix type checking in lambda lift

* ONNX NMS working on GPU, had to remove threading from some kernels

fix lint

fix lambda lift tests

fix unit tests

respond to review comments

fix lint

* better parallelize get_valid_counts

* improve nms parallelization

* respond to cuda/thrust enablement issue

Co-authored-by: Jared Roesch <roeschinc@gmail.com>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021