[TVMC] Fix PyTorch support #7359
Conversation
A PyTorch model could not be compiled through tvmc because the shape of the input tensor could not be deduced from the model after it has been saved. We've added an --input-shape parameter to tvmc compile and tvmc tune that allows the inputs to be specified for PyTorch models.
```python
def parse_input_shapes(xs):
    """Turn the string from --input-shape into a list."""
```
It would be good to have an example here that describes the input format and the expected output format, similar to what you have in test_parse_input_shapes__turn_into_list.
```python
"--input-shape",
type=common.parse_input_shapes,
metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
help="for PyTorch, e.g. '(1,3,224,224)'",
```
Maybe clarify that it is in fact mandatory for PyTorch.
Agree. It's confusing to see such a general option only for PyTorch. I would suggest the following changes:
- Make --input-shape a general option for all frontends. If present, we skip the input shape inference.
- --input-shape is optional by default. However, if users want to process a PyTorch model but don't specify --input-shape, we throw an error in the PyTorch frontend.
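A minimal sketch of the behavior proposed above, assuming an exception type like tvmc's TVMCException (the function name, signature, and stubbed inference here are illustrative, not the PR's actual code):

```python
class TVMCException(Exception):
    """Stand-in for tvmc's exception type (assumed name)."""


def resolve_input_shapes(frontend_name, input_shapes=None):
    """--input-shape is optional for all frontends, skips inference when
    given, and is mandatory for PyTorch (sketch of the suggested rule)."""
    if input_shapes is not None:
        # Shapes were supplied on the command line: skip inference entirely.
        return input_shapes
    if frontend_name == "pytorch":
        # PyTorch models lose their input shapes on torch.jit.save/load,
        # so inference is impossible and the flag becomes mandatory.
        raise TVMCException("--input-shape must be specified for PyTorch models")
    # Other frontends can infer shapes from the model itself (stubbed here).
    return "inferred-from-model"
```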
```python
# Remove white space and extract numbers
strshape = shape[1].replace(" ", "").split(",")
```
It would be safer and easier to remove all spaces in xs at the beginning of this function.
```python
try:
    shapes.append([int(i) for i in strshape])
except ValueError:
    raise argparse.ArgumentTypeError(f"expected numbers in shape '{shape[1]}'")
```
Consider the following two input shapes:
- `(8)`: `shapes=[8]`
- `(8,)`: ValueError, because `strshape` would be `[8, ""]`.

Accordingly, I guess your intention is `(8)` instead of `(8,)`. However, this is inconsistent with Python syntax, so it might confuse people. I have two proposals to deal with this:
1. Use list syntax instead of tuple syntax, so that the semantics are clear, and we can simply use the JSON loader to deal with all variants (e.g., spaces):

```python
xs = "[1,3,224,224], [32]"
# Wrap in brackets so the comma-separated shapes form one JSON list.
shapes = json.loads("[" + xs + "]")  # [[1, 3, 224, 224], [32]]
```

2. Follow Python syntax to accept only `(8,)` and throw an error for `(8)`, which is treated as an integer instead of a tuple because the parentheses are simplified away by Python. In this case, I would suggest using `eval` to deal with all variants:

```python
xs = "(1,3,224,224), (32,)"
# Remember to disable all local and global symbols to isolate this expression.
shapes = eval(xs, {}, {})  # ((1, 3, 224, 224), (32,))
```

Either way is fine for me; please update the help message and make sure you have a unit test to cover the corner cases.
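To make the first proposal concrete, here is a hedged sketch of what `parse_input_shapes` could look like with the JSON-based approach (the function name comes from the PR; the validation details and error messages are assumptions):

```python
import argparse
import json


def parse_input_shapes(xs):
    """Parse an --input-shape string such as "[1,3,224,224], [32]"
    into a list of shape lists, using JSON list syntax."""
    try:
        # Wrap in brackets so multiple comma-separated shapes form one JSON list.
        shapes = json.loads("[" + xs + "]")
    except json.JSONDecodeError:
        raise argparse.ArgumentTypeError(f"invalid shape string '{xs}'")
    for shape in shapes:
        # Each shape must itself be a list of integer dimensions.
        if not isinstance(shape, list) or not all(isinstance(d, int) for d in shape):
            raise argparse.ArgumentTypeError(f"expected a list of integers, got '{shape}'")
    return shapes
```

Because the JSON parser handles whitespace itself, this also sidesteps the space-stripping concern raised earlier in the review.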
```python
if input_shape:
    raise TVMCException("--input-shape is not supported for {}".format(self.name()))
```
This is definitely too ad hoc.
```python
# pylint: disable=C0415
import torch

traced_model = torch.jit.load(path)

inputs = list(traced_model.graph.inputs())[1:]
```
Is this approach not working at all? If it works for some cases, we should still use it first when --input-shape is missing.
I looked into this, and I didn't find a way to extract the inputs from the model after it has been saved and loaded. I asked on the PyTorch forum as well (https://discuss.pytorch.org/t/input-size-disappears-between-torch-jit-save-and-torch-jit-load/108955), and since I received a grand total of zero responses, I suspect it is a deliberate design decision. If there were a way, it would be good to keep it, of course, but in this form it no longer works.
```diff
@@ -389,6 +403,8 @@ def load_model(path, model_format=None):
     model_format : str, optional
         The underlying framework used to create the model.
         If not specified, this will be inferred from the file type.
+    input_shape : list, optional
+        The shape of the input tensor for PyTorch models.
```
Ditto. Make it general instead of only for PyTorch.
Include the functionalities in #7366. |