[Torch] Fix dtype handling for modules with integer parameters #6311
Merged
Conversation
siju-samuel approved these changes on Aug 21, 2020
LGTM
The annoying warning "WARNING:root:Untyped Tensor found, assume it is float32" is gone. Thanks.
Thanks @masahi. This PR is merged.
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request on Aug 26, 2020:
…e#6311) * return the correct type for GetAttr node * keep _get_pytorch_value_type intact * add test and handle quantized param
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request on Aug 26, 2020:
…e#6311) * return the correct type for GetAttr node * keep _get_pytorch_value_type intact * add test and handle quantized param
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request on Sep 2, 2020:
…e#6311) * return the correct type for GetAttr node * keep _get_pytorch_value_type intact * add test and handle quantized param
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request on Sep 3, 2020:
…e#6311) * return the correct type for GetAttr node * keep _get_pytorch_value_type intact * add test and handle quantized param
This fixes the interesting typing problem raised in #6300.

In TorchScript, parameters of operations like conv are always accessed via prim::GetAttr nodes; the input weight %6, for example, is accessed through such a node. The problem is that Torch cannot figure out the correct type of GetAttr nodes, so when we visit aten::_convolution to get its input types, we have to assume the input is an untyped tensor and emit the annoying warning "Untyped Tensor found, assume it is float32".
This hasn't been a big issue so far because parameters are usually float anyway. But #6300 brought a use case with integer parameters as well as float ones, and changing default_dtype to int doesn't solve it.

So I added a workaround for getting the dtype of GetAttr nodes. Every GetAttr node has a corresponding parameter in the original PyTorch module with a known, correct dtype, and every PyTorch parameter tensor has a corresponding Relay Var (via the convert_params(...) function). So inside _get_input_types, when we find a GetAttr node, we return the dtype of the corresponding Relay Var instead of returning default_dtype.

This is the solution I came up with that requires minimal change. It fixes the problem, but I feel it is a bit hacky. Please let me know if there are better ways to handle this. The test case I added is a minimal reproduction of the issue raised in #6300.
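The lookup described above can be sketched in isolation like this (a minimal, self-contained model of the idea; the function names and dict shapes are illustrative, not TVM's actual API):

```python
DEFAULT_DTYPE = "float32"

def build_getattr_dtype_map(params):
    """Record the known dtype of each converted parameter.

    params maps a PyTorch attribute name to the dtype of the
    Relay Var created for it during parameter conversion.
    """
    return dict(params)

def get_input_dtype(node_kind, attr_name, getattr_dtypes,
                    default_dtype=DEFAULT_DTYPE):
    """Return the recorded dtype for a GetAttr node's parameter;
    fall back to the default only for genuinely untyped inputs."""
    if node_kind == "prim::GetAttr" and attr_name in getattr_dtypes:
        return getattr_dtypes[attr_name]
    return default_dtype

# An integer parameter now gets its true dtype instead of float32:
dtypes = build_getattr_dtype_map({"weight": "float32", "index": "int64"})
print(get_input_dtype("prim::GetAttr", "index", dtypes))   # int64
print(get_input_dtype("prim::Constant", "x", dtypes))      # float32
```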
As a bonus, there are no more annoying "Untyped Tensor found..." warnings when working with traced models.

Please review @siju-samuel @t-vi