[Torch] More graph rewrites for Faster RCNN / MaskRCNN #7346

Merged 11 commits into apache:main on Jan 27, 2021

Conversation

masahi
Member

@masahi masahi commented Jan 27, 2021

This PR adds two new graph rewrites to optimize Faster RCNN / MaskRCNN. Happy to split them into two PRs if preferred.

The first one exploits the fact that in PyTorch detection models, NMS is always followed by a post-NMS topk, as shown below.

https://github.com/pytorch/vision/blob/8ebfd2f5d5f1792ce2cf5a2329320f604530a68e/torchvision/models/detection/rpn.py#L272-L275

We can extract that topk parameter and use it as the max_out_size parameter in our NMS. This brings a good speedup (4.51 ms -> 4.11 ms), and a further speedup can be expected once we have a TIR while loop (cc @tqchen).
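
For context, here is a minimal sketch of the equivalence being exploited (a hypothetical standalone example, not the rewrite code itself; post_nms_top_n stands in for the extracted topk parameter):

import torch
from torchvision.ops import nms

def nms_then_topk(boxes, scores, iou_threshold, post_nms_top_n):
    # What the PyTorch detection models do: NMS followed by a topk slice.
    # nms returns the indices of the kept boxes sorted by decreasing score,
    # so keeping only the first post_nms_top_n survivors is equivalent to
    # running NMS with a maximum output size, which is what the rewrite
    # folds into TVM's NMS, letting the kernel stop early.
    keep = nms(boxes, scores, iou_threshold)
    return keep[:post_nms_top_n]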

The second one replaces the repeated scatter loop in

https://github.com/pytorch/vision/blob/6315358dd06e3a2bcbe9c1e8cdaa10898ac2b308/torchvision/ops/poolers.py#L20-L29

with something like this:

# For each feature map level, collect the indices of the RoIs assigned to it.
indices_per_level = []
for level in range(num_scales):
    idx_in_level = torch.where(levels == level)[0]
    indices_per_level.append(idx_in_level)

# Concatenate the per-level roi_align outputs and their RoI indices, then
# restore the original RoI order with a single argsort + batched gather.
stacked_features = torch.cat(roi_align_results, dim=0)
stacked_indices = torch.cat(indices_per_level, dim=0)
argsort_indices = torch.argsort(stacked_indices)
return stacked_features[argsort_indices, :]

i.e., we are able to remove torch.zeros (which turns out to be very expensive, due to too many any_dim generated by Relay) and the repeated 4D scatters (which are slow because scatters cannot be parallelized well). Instead, we can do a concat, an argsort, and a batched gather to get an equivalent result, which is much more efficient. This transformation is not at all obvious; I think it is a great example of the power of graph rewriting. It cuts more than 10 milliseconds from MaskRCNN / FasterRCNN.
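
As a sanity check, here is a toy 2D version of the equivalence (hypothetical shapes and values; the real pooler operates on 4D NCHW tensors):

import torch

num_rois, channels, num_scales = 5, 3, 3
levels = torch.tensor([1, 0, 1, 0, 2])  # feature map level assigned to each RoI
# Per-level roi_align outputs, ordered as torch.where(levels == level)[0]
feats = [torch.randn(int((levels == l).sum()), channels) for l in range(num_scales)]

# Original approach: zeros + repeated scatters
result = torch.zeros(num_rois, channels)
for l in range(num_scales):
    result[torch.where(levels == l)[0]] = feats[l]

# Rewritten approach: concat + argsort + batched gather
order = torch.argsort(torch.cat([torch.where(levels == l)[0] for l in range(num_scales)]))
assert torch.equal(result, torch.cat(feats, dim=0)[order])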

Unfortunately, I expect this PR to be hard to review; let me know if you have any questions. I tried to give detailed comments to aid understanding.

This concludes the series of PRs I did to optimize MaskRCNN on GPU + VM; here are the current numbers. Surprisingly, NVPTX generates much better code for the dynamic injective ops, which are one of the bottlenecks in MaskRCNN due to a certain limitation in Relay + TE (too many unnecessary any_dim generated). I hope we can discuss this performance result further in the forum.

The TVM results are obtained after auto-scheduling. In MaskRCNN, there are some dynamic-batch conv2d and conv2d_transpose ops that cannot be tuned and are hence extremely slow.

cublas is always required to get a reasonable result, since there is one big dynamic dense op that is extremely slow with the default schedule.

GTX 1070 Ti, on an input of shape (1, 3, 300, 300)

Faster RCNN
Torch: 0.0738 sec
TVM with cuda target + cublas: 0.0712 sec
TVM with nvptx target + cublas: 0.0708 sec

MaskRCNN
Torch: 0.115 sec
TVM with cuda target + cublas: 0.166 sec
TVM with nvptx target + cublas: 0.135 sec

Please review @zhiics @kevinthesun @mbrookhart @jwfromm @anijain2305 @trevor-m

Contributor

@trevor-m trevor-m left a comment


Nice, it looks good to me.

I have one question regarding the topk_after_batch_nms_pattern. It seems the topk slice will no longer get applied to the true branch. Before the rewrite, it was applied to the result of the if statement, i.e. both branches. After the rewrite, it is folded into NMS via max_output_size, but that only happens in the false branch. Would that cause problems?

@masahi
Member Author

masahi commented Jan 27, 2021

No, the If there is for guarding against the case where there are no boxes; see

https://github.com/pytorch/vision/blob/8ebfd2f5d5f1792ce2cf5a2329320f604530a68e/torchvision/ops/boxes.py#L78-L79

So applying topk to an empty tensor is a no-op anyway.
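
For illustration, slicing an empty tensor simply returns an empty tensor (hypothetical shapes):

import torch

boxes = torch.empty(0, 4)   # the "no boxes" case guarded by the If
print(boxes[:1000].shape)   # torch.Size([0, 4]): the topk slice is a no-op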

@trevor-m
Contributor

trevor-m commented Jan 27, 2021

> No, the If there is for guarding against the case where there are no boxes; see
>
> https://github.com/pytorch/vision/blob/8ebfd2f5d5f1792ce2cf5a2329320f604530a68e/torchvision/ops/boxes.py#L78-L79
>
> So applying topk to an empty tensor is a no-op anyway.

Got it, thanks! I guess the pattern does not guarantee that the true branch is for the 0-box case, but since this rewrite is only meant to be used for this particular model, it is fine.

Member

@zhiics zhiics left a comment


Thanks for the effort. LGTM. Just a nitpick, feel free to ignore.

python/tvm/relay/frontend/pytorch_utils.py (review comment, resolved)
@zhiics zhiics merged commit 4006bde into apache:main Jan 27, 2021
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 11, 2021
* add post nms topk to max_out_size rewrite

* add argsort conversion

* scatter pattern first cut

* matching seems to working

* dup matching fixed

* add converter

* conversion seems working

* add reshape, use take

* remove pytorch argsort converter

* update test

* add doc
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021