Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, Torch] Clean up and refactor PyTorch frontend #4944

Merged
merged 10 commits into from
Feb 28, 2020

Conversation

masahi
Copy link
Member

@masahi masahi commented Feb 26, 2020

This is a follow up to #4497. The main motivation is to make it easier to add support for control flow and quantized ops later. This PR itself doesn't contain any new functionality. The diff for the upcoming control flow PR, building on top of the refactoring made in this PR, is here.

  • Remove the class boilerplate. There is no need for using a big class with many states in the frontend really. Instead, I made everything just a function. This departs from other frontends, but I think it is better suited for parsing Torch IR in particular. In my upcoming control flow PR, I need to use recursion to parse conditionals and loops, and it is naturally expressed in a functional style.

  • Simplify the parsing "main loop" by removing as much unnecessary intermediate variables as possible. Inputs to each relay op, constants and input types are retrieved on the fly. We now keep updating only two variables, the outputs of intermediate relay ops and indices into that outputs.

  • Simplify the prim::GetAttr parsing logic in parse_params(...). The current implementation is difficult to understand and if you only look at the code it is not clear what it does. prim::GetAttr is inherently recursive, so I use some recursion stuff to handle it. I think it is much clearer this way. As a bonus, get_use_chains(...) function I added can be used for other purpose as well.

Please review @zhiics @icemelon9 @alexwong @jwfromm.
Aside from prim::GetAttrmechanics, it doesn't require a detailed understanding of Torch IR.

@alexwong
Copy link
Contributor

Thanks for doing this and looking forward to seeing the coming, additional PyTorch support :)

@masahi
Copy link
Member Author

masahi commented Feb 26, 2020

@alexwong The input name issue is fixed, see the last commit

python/tvm/relay/frontend/pytorch.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/pytorch.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/pytorch.py Show resolved Hide resolved
python/tvm/relay/frontend/pytorch.py Outdated Show resolved Hide resolved
@masahi masahi force-pushed the torch-refactor branch 2 times, most recently from d216c0b to 98675af Compare February 27, 2020 06:50
@anijain2305
Copy link
Contributor

Thanks for the changes. I will take a deeper look tomorrow.

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

I could not follow the get_attr_chain code section, but it is mostly because I am not familiar with torch script. So, good to go from my side.

@masahi
Copy link
Member Author

masahi commented Feb 28, 2020

Merging now, since my next PRs are pending. Thanks @anijain2305 @zhiics

@masahi masahi merged commit 7ccb436 into apache:master Feb 28, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* The initial import of refactored implementation, all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
* The initial import of refactored implementation, all tests passed

* enable mobilenet v2 test

* minor cleanup

* reorg

* fix lint

* use input names that come with torch IR

* fix typo

* introduce parse_operators

* fix lint

* add _ prefix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants