[Frontend][Torch] Fix up graph input handling #5204

Merged
merged 6 commits into from
Apr 2, 2020

Conversation

jjohnson-arm
Contributor

Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

From: https://discuss.tvm.ai/t/pytorch-frontend-graph-input-names-can-change-using-loaded-torchscript/6055

Split as two commits to make it easier to review:

  • Simplify operator input handling
    • remove outputs and output_index_map and extend use of input_vars
  • Allow user supplied input names to override graph inputs
    • PyTorch inputs now expected as list rather than dictionary - [(name, shape), (name, shape)...]
    • Input names given with from_pytorch() are now set as the graph input names and should be used in set_input()
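
The renaming described in the second commit can be pictured with a small, library-free sketch. This is not the TVM implementation: `map_graph_inputs` and the sample graph input names are hypothetical, and the only detail taken from the description above is that inputs arrive as a list of (name, shape) pairs whose names replace the graph's own input names.

```python
def map_graph_inputs(ir_input_names, input_shapes):
    """Pair each TorchScript graph input with a user-supplied (name, shape).

    Hypothetical helper for illustration only; not part of TVM.
    Returns {ir_name: (user_name, shape)} in graph order.
    """
    if len(ir_input_names) != len(input_shapes):
        raise ValueError("expected one (name, shape) pair per graph input")
    return {ir_name: (user_name, tuple(shape))
            for ir_name, (user_name, shape) in zip(ir_input_names, input_shapes)}


# Graph input names as they might look after saving and reloading a
# TorchScript module (made-up values).
ir_names = ["input.1", "x.2"]
# New-style input description: a list of (name, shape) pairs.
input_shapes = [("input0", (1, 3, 224, 224)), ("input1", (1, 10))]

mapping = map_graph_inputs(ir_names, input_shapes)
user_names = [name for name, _ in mapping.values()]
print(user_names)  # ['input0', 'input1']
```

Under this scheme the user-facing names stay stable even if the graph's internal names change, which is the problem raised in the linked forum thread.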

Review request: @masahi

@masahi
Member

masahi commented Apr 1, 2020

@jjohnson-arm need to update the tutorial too.

-    for output_name, output in name_output_pairs:
-        output_index_map[output_name] = len(outputs)
-        outputs.append(output)
+def _update_inputs_from_pairs(name_input_pairs, input_vars):
Member

I think we don't need this function anymore. Dict's update method can be used.
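
As a minimal sketch of that suggestion: `dict.update` accepts any iterable of key/value pairs, so a list of (name, value) tuples can be merged without a helper function. The dictionary contents below are placeholder strings, not real Relay expressions.

```python
# Table of already-converted values, keyed by name (placeholder strings).
outputs = {"input0": "var0"}

# Pairs as a list of (name, value) tuples, e.g. what a conversion step
# might hand back.
name_output_pairs = [("conv1", "expr1"), ("relu1", "expr2")]

# dict.update accepts any iterable of key/value pairs, not only a dict.
outputs.update(name_output_pairs)

# A zip object works too, since it yields 2-tuples.
outputs.update(zip(["pool1"], ["expr3"]))

print(list(outputs))  # ['input0', 'conv1', 'relu1', 'pool1']
```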

@masahi
Member

masahi commented Apr 1, 2020

Please update the doc here: https://github.com/apache/incubator-tvm/blob/e722301a1c8be3c7052273961b8a408ca5524c76/python/tvm/relay/frontend/pytorch.py#L1434-L1436

We should warn that these names need to be kept around until deployment time. Our suggestion is to choose something obvious that doesn't require remembering, such as "input0", "input1", etc.
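
One way to follow that naming suggestion is to generate the names from the input positions. The helper below is a hypothetical convenience, not something provided by TVM; it only builds the (name, shape) list described in this PR.

```python
def default_input_shapes(shapes, prefix="input"):
    """Build [(name, shape), ...] with predictable names input0, input1, ...

    Hypothetical helper for illustration; not part of TVM.
    """
    return [("%s%d" % (prefix, idx), tuple(shape))
            for idx, shape in enumerate(shapes)]


input_shapes = default_input_shapes([(1, 3, 224, 224), (1, 10)])
input_names = [name for name, _ in input_shapes]

print(input_shapes)  # [('input0', (1, 3, 224, 224)), ('input1', (1, 10))]
```

Because the names depend only on position, they can be regenerated at deployment time instead of being stored alongside the compiled model.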

@masahi
Member

masahi commented Apr 1, 2020

I'd still like to retain the original meaning of input_vars, because these really are relay.Var and represent values that come from outside. outputs is for storing intermediate outputs from preceding relay ops.

I think you can simply replace input_vars with outputs and also inputs with outputs everywhere.
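
The distinction can be illustrated with a toy conversion loop. This is a sketch under stated assumptions, not the frontend's code: `convert_graph` and the op tuples are hypothetical, and strings stand in for Relay expressions. A single `outputs` table is seeded with the graph inputs and then extended with each op's result, so later ops can look up either kind of value by name.

```python
def convert_graph(input_names, ops):
    """Toy stand-in for a frontend conversion loop (not TVM code)."""
    # Seed the table with the graph inputs. In the real frontend these
    # would be relay.Var objects; strings are used here as placeholders.
    outputs = {name: "var(%s)" % name for name in input_names}
    for node_name, op_name, arg_names in ops:
        # Each op reads its arguments from values produced earlier.
        args = [outputs[arg] for arg in arg_names]
        outputs[node_name] = "%s(%s)" % (op_name, ", ".join(args))
    return outputs


# (node name, op name, names of the values it consumes); made-up graph.
ops = [
    ("n1", "conv", ["input0"]),
    ("n2", "relu", ["n1"]),
    ("n3", "add", ["n2", "input0"]),
]
outputs = convert_graph(["input0"], ops)
print(outputs["n3"])  # add(relu(conv(var(input0))), var(input0))
```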

"""
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for idx, ir_input in enumerate(ir_inputs):
Member

How about

for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
    ...
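
For comparison, here is a small sketch of the index-based loop next to the `zip`-based one suggested above. The sample names and shapes are made up, and the loop bodies are placeholders for whatever the real function does with each input.

```python
ir_inputs = ["input.1", "x.2"]  # made-up graph input names
input_shapes = [("input0", (1, 3, 224, 224)), ("input1", (1, 10))]

# Index-based version: enumerate, then index into the second list.
by_index = {}
for idx, ir_input in enumerate(ir_inputs):
    name, shape = input_shapes[idx]
    by_index[ir_input] = (name, shape)

# zip-based version: iterate over both lists in lockstep and unpack
# the (name, shape) pair directly in the loop header.
by_zip = {}
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
    by_zip[ir_input] = (name, shape)

print(by_index == by_zip)  # True
```

Note that `zip` stops at the shorter sequence without raising an error, so a length check beforehand is worth considering if the two lists can differ in size.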


 params = script_module.state_dict()
-input_vars = _get_relay_input_vars(input_shapes)
+input_vars = _get_relay_input_vars(graph, input_shapes)
Member

input_vars -> outputs

 """
-Add quant params to outputs so that they can be referenced by other
+Add quant params to inputs so that they can be referenced by other
 ops later. Weights are quantized here.
Member

For L104 and L107, please keep outputs

-    input_names = [output_index_map[name]
-                   for name in _get_input_names(op_node)]
-    return [outputs[name] for name in input_names]
+def _get_op_inputs(op_node, input_vars):
Member

@masahi masahi Apr 1, 2020

input_vars -> outputs

Because inputs are not relay.Var.

input_shapes = list(zip(input_names, ishapes))

inputs = [torch.randn(shape, dtype=torch.float)
          for name, shape in input_shapes]
Member

for shape in ishapes

@masahi
Member

masahi commented Apr 1, 2020

cc @alexwong @jwfromm @pyjhzwh This is an API change and will break your code, but for a good reason. See the discussion above.

Member

@masahi masahi left a comment

  • Update the tutorial
  • Update the doc
  • Remove _update_inputs_from_pairs and use Dict's update method directly everywhere
  • Minor code style change (for loop)
  • input_vars -> outputs and inputs -> outputs.

@jjohnson-arm
Contributor Author

> • Update the tutorial
> • Update the doc
> • Remove _update_inputs_from_pairs and use Dict's update method directly everywhere
> • Minor code style change (for loop)
> • input_vars -> outputs and inputs -> outputs.

Ok - will do. I wasn't sure about the inputs rename, happy to change it back.

@jjohnson-arm
Contributor Author

jjohnson-arm commented Apr 2, 2020

> @jjohnson-arm need to update the tutorial too.

Will look at this now.

@masahi
Member

masahi commented Apr 2, 2020

@jjohnson-arm Unfortunately, you've just hit a known flaky test failure. Please comment out the get_valid_count test. See #4901 (comment)

Also, have you verified that the torch frontend tests work with this PR? I'm not sure that some of the uses of the update() method are supported.

@jjohnson-arm
Contributor Author

> @jjohnson-arm Unfortunately, you've just hit a known flaky test failure. Please comment out the get_valid_count test. See #4901 (comment)
>
> Also, have you verified that the torch frontend tests work with this PR? I'm not sure that some of the uses of the update() method are supported.

Will comment out the test.
I have been running the tests/python/frontend/pytorch/test_forward.py tests for all my changes, and they work fine (using Python 3.6.10). It seems that Python 3.5 onwards supports passing tuples etc. to dict.update() - https://docs.python.org/3.5/library/stdtypes.html?highlight=update#dict.update, so it should be okay?

@masahi
Member

masahi commented Apr 2, 2020

OK, good to know. I was thinking the argument of update had to be a dict.

@masahi masahi merged commit 03cbf78 into apache:master Apr 2, 2020
@masahi
Member

masahi commented Apr 2, 2020

Thanks @jjohnson-arm, this is merged!

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test