Dynamic ONNX Importer #6351

Merged
merged 25 commits into from
Oct 3, 2020

mbrookhart
Contributor

@mbrookhart mbrookhart commented Aug 27, 2020

Hello Friends,

Over the last couple of months, @electriclilies and I have been working to add more dynamic support to relay ops, to separate the dynamic implementations into a dyn namespace, and to provide a pass for converting ops back to static forms when possible.

The culmination of that work is this PR, which refactors the ONNX importer to directly create dynamic relay graphs instead of using infer_value to make them static in the importer. Longer term, this will allow us to import dynamic models that we can't currently use.

We don't want to cause regressions for anyone, so this PR enables the dynamic_to_static pass by default in the graph runtime. We tested the PR against the ONNX model zoo (https://github.com/onnx/models) and fixed a number of issues in ops that apparently hadn't been tested with dynamic shapes to date.
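
To illustrate the idea behind converting dynamic ops back to static forms, here is a minimal toy sketch in plain Python. It is not TVM code: the `Var`, `Const`, `DynReshape`, and `Reshape` classes and the `dynamic_to_static` function are hypothetical stand-ins invented for this example. A dynamic op takes its shape as an expression, and the rewrite replaces it with a static op only when that shape turns out to be a constant.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Hypothetical mini-IR, invented for illustration only (not TVM's relay IR).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: Tuple[int, ...]


@dataclass(frozen=True)
class DynReshape:
    # The target shape is itself an expression, possibly unknown until run time.
    data: "Expr"
    newshape: "Expr"


@dataclass(frozen=True)
class Reshape:
    # The target shape is a compile-time attribute.
    data: "Expr"
    newshape: Tuple[int, ...]


Expr = Union[Var, Const, DynReshape, Reshape]


def dynamic_to_static(expr):
    """Rewrite DynReshape into Reshape wherever the shape argument is constant."""
    if isinstance(expr, DynReshape):
        data = dynamic_to_static(expr.data)
        shape = dynamic_to_static(expr.newshape)
        if isinstance(shape, Const):
            return Reshape(data, shape.value)
        # Shape depends on a runtime value, so the op stays dynamic.
        return DynReshape(data, shape)
    if isinstance(expr, Reshape):
        return Reshape(dynamic_to_static(expr.data), expr.newshape)
    return expr


x = Var("x")
static_case = dynamic_to_static(DynReshape(x, Const((2, 3))))
dynamic_case = dynamic_to_static(DynReshape(x, Var("shape_input")))
```

In this sketch, `static_case` becomes a `Reshape` with a fixed shape, while `dynamic_case` remains a `DynReshape` because its shape comes from a runtime input.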

An added benefit of this PR is that it removes a severe bottleneck in the infer_value calls. Models with lots of dynamic ops will import and compile much faster than before: BERT Squad from the ONNX model zoo imports and compiles in ~170s on master vs. ~15s with this change.

This PR is not yet complete; we're working on adding strided slice (#6316) to remove the last infer_value calls.

Since we don't want to introduce regressions for anyone, I'd appreciate it if you could test any models you are currently running against this branch and let us know if you run into issues.

Thanks!

cc @masahi @jwfromm @soiferj @siju-samuel Please tag anyone else you think might be interested

@mbrookhart
Contributor Author

cc @zhiics @icemelon9

@mbrookhart mbrookhart force-pushed the mbrookhart/dynamic_onnx branch 2 times, most recently from 4e3bb37 to 8b5899b Compare September 3, 2020 17:45
@mbrookhart
Contributor Author

Thanks to @tmoreau89 for testing some custom models against this branch and finding a regression. Anyone else using ONNX, I'd really appreciate it if you could do the same.

@tqchen
Member

tqchen commented Sep 4, 2020

cc @zhiics @yzhliu

@tmoreau89
Contributor

> Thanks to @tmoreau89 for testing some custom models against this branch and finding a regression. Anyone else using ONNX, I'd really appreciate it if you could do the same.

Happy to! On the custom model I tested, compilation time went from 95.8s down to 1.2s. Nice work!

@zhiics
Member

zhiics commented Sep 4, 2020

@mbrookhart ping me please when it is ready for review

Lily Orth-Smith and others added 6 commits September 11, 2020 10:52
Make OneHot dynamic

Support BatchMatMul with dynamically shaped inputs

fix dynamic broadcast

Add null checks to broadcast_to rel functions

fail more isolated broadcast_to test

use StructuralEqual instead of pointer comparisons in dynamic_to_static pass

add an optional weight freeze argument to onnx importer

convert onnx resize to dynamic op

add dynamic expand to onnx importer

add a shape_func for power

fix BERTSquad, lint

handle onnx graph initializer parameters more intelligently
fix lint

fix Call reference

fix a type issue with expand

fix a bad test refactor

respond to review comments, fix batch matmul tests
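
As a rough analogy for the "StructuralEqual instead of pointer comparisons" commit above, the following toy Python sketch contrasts identity comparison with structural comparison. The `Call` class is a hypothetical stand-in made up for this example and is not TVM's expression type or its StructuralEqual implementation.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical expression node, invented for illustration only.
@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...] = ()


# Two separately built expressions with identical structure.
a = Call("add", (Call("x"), Call("y")))
b = Call("add", (Call("x"), Call("y")))

same_pointer = a is b     # identity: are these the very same object?
same_structure = a == b   # structure: same op and same arguments, recursively
```

A pass that rebuilds expressions produces new objects, so an identity check can miss nodes that are structurally the same; that appears to be the motivation for comparing structure instead.
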
@mbrookhart mbrookhart marked this pull request as ready for review September 11, 2020 20:19
Member

@zhiics zhiics left a comment

Thanks for the great effort. Only left some minor comments.

        plevel=10,
    )
    if is_dynamic(out_type):
        strategy.add_implementation(
Member

shouldn't this one be in generic.py?

Contributor Author

Hmm, something very similar to this is already in generic.py. What I'm trying to do here is short-circuit the schedule if we have dynamic shapes. The x86 schedule, as written, assumes static shapes and breaks during schedule construction if I give it a dynamic input. Is there a cleaner way to do that short-circuit in generic.py?

Member

I am not quite sure what a better way to do this would be, though. @icemelon9 thoughts?

Contributor

What would the behavior be if instead we only had if not is_dynamic(out_type) to register the x86 schedule? I would think that the generic strategy would be used even if we don't re-add it here.

Contributor Author

I'll give that a try! I'll report back shortly

Contributor Author

Unfortunately, it seems like the compile engine can't find any schedules if I do this:

E             File "/home/mbrookhart/repos/mbrookhart_tvm/python/tvm/relay/backend/compile_engine.py", line 289, in lower_call
E               op, call.attrs, inputs, ret_type, target, use_autotvm=False
E             File "/home/mbrookhart/repos/mbrookhart_tvm/python/tvm/relay/backend/compile_engine.py", line 188, in select_implementation
E               best_plevel_impl = max(all_impls, key=lambda x: x.plevel)
E           ValueError: max() arg is an empty sequence
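
The failure mode in this traceback can be reproduced with a small toy sketch. Everything here is hypothetical and written only for illustration: `Impl`, `build_strategy`, and this `select_implementation` are not the real compile engine, they just mirror the `max(all_impls, key=lambda x: x.plevel)` call shown above. If the static-only implementation is skipped for dynamic shapes and nothing else is registered, the candidate list is empty and `max()` raises.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical stand-in for a registered op implementation.
@dataclass
class Impl:
    name: str
    plevel: int


def select_implementation(impls: List[Impl]) -> Impl:
    # Same shape as the call in the traceback: raises ValueError on an empty list.
    return max(impls, key=lambda impl: impl.plevel)


def build_strategy(is_dynamic: bool, register_generic_for_dynamic: bool) -> List[Impl]:
    impls = []
    if not is_dynamic:
        # Schedule that assumes static shapes.
        impls.append(Impl("x86_static", plevel=10))
    elif register_generic_for_dynamic:
        # Explicit fallback so that dynamic shapes still have a candidate.
        impls.append(Impl("generic_fallback", plevel=10))
    return impls


chosen = select_implementation(
    build_strategy(is_dynamic=True, register_generic_for_dynamic=True)
)

try:
    select_implementation(
        build_strategy(is_dynamic=True, register_generic_for_dynamic=False)
    )
    failed = False
except ValueError:
    failed = True
```

Under these assumptions, the dynamic case only works when a fallback is registered explicitly, which matches the observation that simply guarding the x86 schedule with `if not is_dynamic(out_type)` left no schedule to select.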

@zhiics
Member

zhiics commented Sep 15, 2020

@jwfromm @electriclilies please take another look

@masahi
Member

masahi commented Sep 15, 2020

@mbrookhart Does this PR enable compiling one model and running it with input data of different shapes?

@mbrookhart
Contributor Author

@jwfromm @electriclilies @zhiics @csullivan Could you take another look?

tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
Contributor

PR #6337 is now merged, should we enable GPU here or are there still issues?

Contributor Author

I'm hitting issues on dynamic strided slice and topk. I was going to wait until I had that fixed to enable them in the onnx frontend.

@mbrookhart
Contributor Author

Ping?

@zhiics
Member

zhiics commented Oct 1, 2020

Contributor

@electriclilies electriclilies left a comment

Overall, this looks good to me!

One thing I did notice is that in the onnx importer itself, you put warnings in the comments telling people that they will need to run the dynamic_to_static pass because some operators do not support dynamic shapes yet.

We should probably add a note or warning to the importer documentation and tutorials. I'm not sure if that should be part of this PR or separate, though.

@zhiics
Member

zhiics commented Oct 2, 2020

@mbrookhart please see Lily's last comment.

@mbrookhart
Contributor Author

@zhiics @electriclilies Added some doc strings

@electriclilies
Contributor

@mbrookhart Thanks! LGTM

@zhiics
Member

zhiics commented Oct 2, 2020

@mbrookhart cool. I will merge once CI passes

@zhiics zhiics merged commit 2658ebe into apache:master Oct 3, 2020
@zhiics
Member

zhiics commented Oct 3, 2020

TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 13, 2020
* Change onnx importer to use dynamic upsampling3d (neo-ai#3)

fix pylint

* Refactor ONNX frontend to be dynamic

Make OneHot dynamic

Support BatchMatMul with dynamically shaped inputs

fix dynamic broadcast

Add null checks to broadcast_to rel functions

fail more isolated broadcast_to test

use StructuralEqual instead of pointer comparisons in dynamic_to_static pass

add an optional weight freeze argument to onnx importer

convert onnx resize to dynamic op

add dynamic expand to onnx importer

add a shape_func for power

fix BERTSquad, lint

handle onnx graph initializer parameters more intelligently

* Dynamic ONNX importer: Upsampling and Pad (neo-ai#2)

fix lint

fix Call reference

fix a type issue with expand

fix a bad test refactor

respond to review comments, fix batch matmul tests

* black format

* fix batch matmul test

* add dynamic strided slice to the onnx importer

* fix clip importer

* fix qnn tutorial

* fix bad merge, respond to review comments

* add a simple dynamic model test

* Add dynamic-shaped autopadding to convolution and pooling ops

* fix dynamic issues in a few ops

* fix pylint

* disable tests onnxrt doesn't support

* fix pytorch test

* respond to review comments

* add documentation about partially supporting dynamic shapes

Co-authored-by: Lily Orth-Smith <lorthsmith@octoml.ai>
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 14, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 15, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 15, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 16, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 16, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Oct 19, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Oct 19, 2020