-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend][Torch] Fix up graph input handling #5204
Conversation
@jjohnson-arm need to update the tutorial too. |
python/tvm/relay/frontend/pytorch.py
Outdated
for output_name, output in name_output_pairs: | ||
output_index_map[output_name] = len(outputs) | ||
outputs.append(output) | ||
def _update_inputs_from_pairs(name_input_pairs, input_vars): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need this function anymore. Dict's update
method can be used.
please update the doc here https://github.com/apache/incubator-tvm/blob/e722301a1c8be3c7052273961b8a408ca5524c76/python/tvm/relay/frontend/pytorch.py#L1434-L1436 We should warn that this names need be around until deployment time. Our suggestion is to choose something obvious, that doesn't require remembering. Something like "input0", "input1" etc |
I still like to retain the original meaning of I think you can simply replace |
python/tvm/relay/frontend/pytorch.py
Outdated
""" | ||
input_vars = {} | ||
ir_inputs = _get_graph_input_names(graph) | ||
for idx, ir_input in enumerate(ir_inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
...
python/tvm/relay/frontend/pytorch.py
Outdated
|
||
params = script_module.state_dict() | ||
input_vars = _get_relay_input_vars(input_shapes) | ||
input_vars = _get_relay_input_vars(graph, input_shapes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_vars
-> outputs
""" | ||
Add quant params to outputs so that they can be referenced by other | ||
Add quant params to inputs so that they can be referenced by other | ||
ops later. Weights are quantized here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For L104 and L107, please keep outputs
python/tvm/relay/frontend/pytorch.py
Outdated
input_names = [output_index_map[name] | ||
for name in _get_input_names(op_node)] | ||
return [outputs[name] for name in input_names] | ||
def _get_op_inputs(op_node, input_vars): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_vars
-> outputs
Because inputs are not relay.Var
.
input_shapes = list(zip(input_names, ishapes)) | ||
|
||
inputs = [torch.randn(shape, dtype=torch.float) | ||
for name, shape in input_shapes] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for shape in ishapes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Update the tutorial
- Update the doc
- Remove
_update_inputs_from_pairs
and use Dict's update method directly everywhere - Minor code style change (for loop)
input_vars
->outputs
andinputs
->outputs
.
Ok - will do. I wasn't sure about the inputs rename, happy to change it back. |
Will look at this now. |
@jjohnson-arm Unfortunately, you've just hit a known flaky test failure. Please comment out the get_valid_count test. See #4901 (comment) Also have you verified that torch frontend tests work with this PR? I'm not sure some of the usage of |
Will comment out the test. |
ok good to know. I was thinking the arg of update should be dict. |
Thanks @jjohnson-arm this is merged! |
* [Frontend][Torch] Simplify operator input handling * [Frontend][Torch] Allow user supplied input names to override graph inputs * Fix pylint issues * Updates from code review feedback * Fix tutorial to use shape list input * Disable intermittent test failure in topi vision test
* [Frontend][Torch] Simplify operator input handling * [Frontend][Torch] Allow user supplied input names to override graph inputs * Fix pylint issues * Updates from code review feedback * Fix tutorial to use shape list input * Disable intermittent test failure in topi vision test
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.
From: https://discuss.tvm.ai/t/pytorch-frontend-graph-input-names-can-change-using-loaded-torchscript/6055
Split as two commits to make it easier to review:
Review request: @masahi