
[TVMC] Fix PyTorch support #7359

Status: Closed (1 commit, not merged)

Conversation

@ekalda (Contributor) commented on Jan 28, 2021:

A PyTorch model could not be compiled through tvmc because the shape
of the input tensor could not be deduced from the model after it has been
saved. We've added an --input-shape parameter to tvmc compile and
tvmc tune that allows the inputs to be specified for PyTorch models.

@ekalda (Contributor, Author) commented on Jan 28, 2021:

cc @leandron @u99127

@leandron (Contributor) left a comment:

Thanks for this fix.

cc @masahi @comaniac to have a look.


def parse_input_shapes(xs):
    """Turn the string from --input-shape into a list.

A reviewer (Contributor) commented:

It would be good to have an example here that describes the input format and the expected output format, similar to what you have in test_parse_input_shapes__turn_into_list.
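To illustrate the reviewer's suggestion, here is a hedged sketch of what such a docstring example could look like; the function body and the regex-based parsing are illustrative only and do not claim to match the PR's actual implementation:

```python
import argparse
import re


def parse_input_shapes(xs):
    """Turn the string from --input-shape into a list of shapes.

    For example, "(1,3,224,224),(1,10)" becomes [[1, 3, 224, 224], [1, 10]].
    """
    shapes = []
    # Parse each parenthesised, comma-separated group of integers.
    for group in re.findall(r"\(([\d,\s]+)\)", xs):
        dims = group.replace(" ", "").strip(",").split(",")
        try:
            shapes.append([int(d) for d in dims])
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected numbers in shape '({group})'")
    return shapes
```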

"--input-shape",
type=common.parse_input_shapes,
metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
help="for PyTorch, e.g. '(1,3,224,224)'",
A reviewer (Contributor) commented:

Maybe clarify that it is in fact mandatory for PyTorch.

A reviewer (Contributor) commented:

Agree. It's confusing to see such a general option available only for PyTorch. I would suggest the following changes:

  1. Make --input-shape a general option for all frontends. If present, we skip the input shape inference.
  2. --input-shape is optional by default. However, if users want to process a PyTorch model but don't specify --input-shape, we throw an error in the PyTorch frontend.
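A minimal sketch of proposal 2, assuming a hypothetical load_pytorch_model helper (not the PR's real code): the option stays optional at the CLI level, and only the PyTorch frontend insists on it.

```python
import argparse


def load_pytorch_model(path, input_shapes=None):
    """Hypothetical PyTorch frontend entry point: shapes are mandatory here."""
    if input_shapes is None:
        raise ValueError("--input-shape is mandatory for PyTorch models")
    # The real frontend would call torch.jit.load(path) and build the Relay
    # module from (path, input_shapes); we just echo the arguments here.
    return path, input_shapes


parser = argparse.ArgumentParser()
parser.add_argument(
    "--input-shape",
    default=None,  # optional: most frontends can infer shapes themselves
    help="override the inferred input shapes, e.g. '(1,3,224,224)'",
)

args = parser.parse_args([])  # no --input-shape given
try:
    load_pytorch_model("model.pt", args.input_shape)
except ValueError as err:
    print(err)
```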

Comment on lines +165 to +166
# Remove white space and extract numbers
strshape = shape[1].replace(" ", "").split(",")
A reviewer (Contributor) commented:

It would be safer and easier to remove all spaces from xs at the beginning of this function.
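The reviewer's point can be shown with a tiny hedged fragment (the helper name is illustrative, not from the PR):

```python
def normalize_shape_string(xs):
    """Strip all whitespace once up front so later splits never see spaces."""
    return "".join(xs.split())
```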

try:
    shapes.append([int(i) for i in strshape])
except ValueError:
    raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'")
A reviewer (Contributor) commented:

Consider the following two input shapes:

  • (8): shapes=[8]
  • (8,): ValueError, because strshape would be [8, ""].

Accordingly, I guess your intention is (8) instead of (8,). However, this is inconsistent with Python syntax, so it might confuse people. I have two proposals to deal with this:

  1. Use list syntax instead of tuple syntax, so that the semantics are clear, and we can simply use the JSON loader to deal with all variants (e.g., spaces):
    xs = "[[1,3,224,224], [32]]"
    shapes = json.loads(xs)  # [[1,3,224,224],[32]]
  2. Follow Python syntax to only accept (8,) and throw an error for (8), which is treated as an integer instead of a tuple because the brackets will be simplified away in Python. In this case, I would suggest using eval to deal with all variants:
    xs = "(1,3,224,224), (32,)"
    shapes = eval(xs, {}, {})  # Remember to disable all local and global symbols to isolate this expression.
    # shapes=((1,3,224,224),(32,))

Either way is fine with me; please update the help message and make sure you have a unit test covering the corner cases.
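Proposal 1 can be sketched as follows; parse_input_shapes_json is a hypothetical name for illustration, not a function in the PR:

```python
import argparse
import json


def parse_input_shapes_json(xs):
    """Parse a list-syntax --input-shape string with the JSON loader.

    "[[1,3,224,224], [32]]" becomes [[1, 3, 224, 224], [32]].
    """
    try:
        shapes = json.loads(xs)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(f"invalid --input-shape '{xs}': {err}")
    # Reject anything that is not a list of non-empty integer lists, e.g. "(8)".
    if not isinstance(shapes, list) or not all(
        isinstance(s, list) and s and all(isinstance(d, int) for d in s) for s in shapes
    ):
        raise argparse.ArgumentTypeError(
            f"--input-shape must be a list of non-empty integer lists, got '{xs}'"
        )
    return shapes
```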

Comment on lines +108 to +110

if input_shape:
    raise TVMCException("--input-shape is not supported for {}".format(self.name()))
A reviewer (Contributor) commented:

This is definitely too ad hoc.

# pylint: disable=C0415
import torch

traced_model = torch.jit.load(path)

inputs = list(traced_model.graph.inputs())[1:]
A reviewer (Contributor) commented:

Is this approach not working at all? If it works for some cases, we should still use it first when --input-shape is missing.

@ekalda (Contributor, Author) replied:

I looked into this and couldn't find a way to extract inputs from the model after it has been saved and loaded. I asked on the PyTorch forum as well (https://discuss.pytorch.org/t/input-size-disappears-between-torch-jit-save-and-torch-jit-load/108955) and, since I received a grand total of zero responses, I suspect it is a deliberate design decision. If there were a way, it would be good to keep it, of course, but in this form it doesn't work any more.

@@ -389,6 +403,8 @@ def load_model(path, model_format=None):
    model_format : str, optional
        The underlying framework used to create the model.
        If not specified, this will be inferred from the file type.
    input_shape : list, optional
        The shape of the input tensor for PyTorch models.
A reviewer (Contributor) commented:

Ditto: make it general instead of PyTorch-only.

@comaniac comaniac added status: suprceded PR is superceded by another one and removed status: need update need update based on feedbacks status: review in progress labels Jan 29, 2021
@comaniac (Contributor) commented:
The functionalities are included in #7366.

@comaniac comaniac closed this Jan 29, 2021