Skip to content

Commit

Permalink
Improve type handling in PyTorch frontend (apache#5834)
Browse files Browse the repository at this point in the history
* Improve type handling in PyTorch frontend

- Use type information from graph for inputs if available. Check
  against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and
  compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes
  fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the
  "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not
change the signature/requirements of from_pytorch.

* Fix scalar detection using numpy.isscalar

and address other review comments. Thank you @siju-samuel

* refine test criteron on qnn_test::test_serialized_modules, fix bool conversion of const
  • Loading branch information
t-vi authored and Trevor Morris committed Jun 30, 2020
1 parent da3eeab commit b4e5f58
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 187 deletions.
Loading

0 comments on commit b4e5f58

Please sign in to comment.