Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support complex in FX exporter #100554

Closed
wants to merge 45 commits into from

Conversation

titaiwangms
Copy link
Collaborator

@titaiwangms titaiwangms commented May 3, 2023

Stack from ghstack (oldest at bottom):

Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

  1. placeholder: Apply torch.view_as_real() before sending fake tensor to graph building.
  2. call_function: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
  3. Registry: Add is_complex, and supports complex onnxfunction.
  4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
  5. Test cases: input/output view_as_real for result comparisons.

@titaiwangms titaiwangms requested review from BowenBao and abock as code owners May 3, 2023 16:13
@pytorch-bot
Copy link

pytorch-bot bot commented May 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100554

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ddb5acb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label May 3, 2023
titaiwangms added a commit that referenced this pull request May 3, 2023
ghstack-source-id: 9758268225091a89c7dd1954bc5e3771ecb2f926
Pull Request resolved: #100554
@titaiwangms titaiwangms marked this pull request as draft May 3, 2023 16:13
@titaiwangms titaiwangms added module: onnx Related to torch.onnx topic: new features topic category labels Jun 18, 2023
@titaiwangms titaiwangms marked this pull request as ready for review June 19, 2023 03:44
@titaiwangms titaiwangms requested a review from justinchuby June 19, 2023 03:57
titaiwangms added a commit that referenced this pull request Jun 19, 2023
ghstack-source-id: cea8f8c8b39132d368439fd5b9ad208bcdeb0606
Pull Request resolved: #100554
@justinchuby justinchuby self-assigned this Jun 19, 2023
justinchuby
justinchuby previously approved these changes Jun 19, 2023
Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

torch/onnx/_internal/fx/type_utils.py Outdated Show resolved Hide resolved
torch/onnx/_internal/fx/type_utils.py Outdated Show resolved Hide resolved
@justinchuby
Copy link
Collaborator

LGTM. Just a few minor comments

Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
titaiwangms added a commit that referenced this pull request Jul 27, 2023
ghstack-source-id: a722217c4578f73ede800118daeb1932f74dbf06
Pull Request resolved: #100554
Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
# NOTE: ONNX Runtime doesn't support tensor of complex64/complex128, so we
# convert them to float32/float64.
# TODO: Need stft enabled to test complex64/complex128
if isinstance(ref_output, torch.Tensor) and torch.is_complex(ref_output):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add a new input/output adaptor step for this. Checkout RemoveNonTensorInputStep for example.
Feel free to do it in this or follow-up PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

else input
for input in pytorch_inputs
]
ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: input adapter step

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

@BowenBao BowenBao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/ comments. Please also resolve others' comments before landing.

Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
@justinchuby
Copy link
Collaborator

Are we handling any complex constants that may be inside the graph? Maybe that’s an unlikely case that we don’t need to care yet?

Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
@titaiwangms
Copy link
Collaborator Author

Are we handling any complex constants that may be inside the graph? Maybe that’s an unlikely case that we don’t need to care yet?

It's not covered in this PR. We can add it when we have a model or use case.

Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
@titaiwangms titaiwangms added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 27, 2023
Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.

[ghstack-poisoned]
@titaiwangms
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot.

@titaiwangms
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

bobby-palmer pushed a commit to bobby-palmer/pytorch that referenced this pull request Jul 29, 2023
Previous to the PR, the complex dtype would only fail. This PR keeps torch.fx.Graph with complex dtype, while mapping them to float dtype in torchscript(onnx) graph with real representation.

The change happens in multiple files:

1. `placeholder`: Apply torch.view_as_real() before sending fake tensor to graph building.
2. `call_function`: Fill in TorchScriptTensor dtype and shape with real representation dtype and shape.
3. Registry: Add `is_complex`, and supports complex onnxfunction.
4. Dispatcher: Filter with/out complex onnxfunction before opschema matching, based on the dtype in torch args
5. Test cases: input/output view_as_real for result comparisons.
Pull Request resolved: pytorch#100554
Approved by: https://github.com/BowenBao
@facebook-github-bot facebook-github-bot deleted the gh/titaiwangms/17/head branch July 31, 2023 14:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: onnx Related to torch.onnx open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: new features topic category
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

6 participants