[ONNX] Support complex in FX exporter #100554
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100554
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit ddb5acb. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Awesome!
LGTM. Just a few minor comments.
Prior to this PR, exporting a model with a complex dtype would simply fail. This PR keeps the complex dtype in the torch.fx.Graph while mapping it to a float dtype with a real representation in the TorchScript (ONNX) graph. The change spans multiple files:
1. `placeholder`: apply torch.view_as_real() before sending the fake tensor to graph building.
2. `call_function`: fill in the TorchScriptTensor dtype and shape with the real-representation dtype and shape.
3. Registry: add `is_complex` and support complex onnxfunctions.
4. Dispatcher: filter complex/non-complex onnxfunctions before opschema matching, based on the dtype of the torch args.
5. Test cases: apply view_as_real to inputs/outputs for result comparison.
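A minimal sketch of the `placeholder` handling described in item 1, assuming a standalone helper; `to_real_representation` is an illustrative name, not the exporter's actual internal function:

```python
import torch

def to_real_representation(tensor: torch.Tensor) -> torch.Tensor:
    """Return a real-valued view of a complex tensor so graph building only
    sees float dtypes; non-complex tensors pass through unchanged."""
    if torch.is_complex(tensor):
        # complex64 -> float32, complex128 -> float64, with a trailing
        # dimension of size 2 holding (real, imag).
        return torch.view_as_real(tensor)
    return tensor

# A complex64 fake input becomes float32 with an extra trailing dim of 2.
x = torch.randn(3, 4, dtype=torch.complex64)
print(to_real_representation(x).dtype, to_real_representation(x).shape)
# torch.float32 torch.Size([3, 4, 2])
```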
test/onnx/onnx_test_common.py (outdated)
# NOTE: ONNX Runtime doesn't support tensor of complex64/complex128, so we
# convert them to float32/float64.
# TODO: Need stft enabled to test complex64/complex128
if isinstance(ref_output, torch.Tensor) and torch.is_complex(ref_output):
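A hedged sketch of the output-side conversion this excerpt performs, written as a standalone helper for clarity; the helper name is illustrative and the actual test code may structure it differently:

```python
import torch

def convert_complex_outputs_to_real(ref_outputs):
    """View any complex reference output as its real representation so it can
    be compared element-wise against ONNX Runtime results, since ORT has no
    complex64/complex128 tensor type."""
    return [
        torch.view_as_real(out)
        if isinstance(out, torch.Tensor) and torch.is_complex(out)
        else out
        for out in ref_outputs
    ]
```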
You should add a new input/output adaptor step for this. Check out RemoveNonTensorInputStep for an example.
Feel free to do it in this or a follow-up PR.
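For illustration, a rough sketch of what such an adapter step might look like, modeled loosely on the adapter-step pattern mentioned above; the class name, the `apply` signature, and the idea of registering it alongside `RemoveNonTensorInputStep` are assumptions, not the actual torch.onnx internals:

```python
from typing import Any, Sequence

import torch

class ViewComplexAsRealInputStep:
    """Hypothetical input-adapter step: rewrites complex model inputs into
    their real-valued views so the exported real-representation ONNX graph
    receives plain float tensors."""

    def apply(self, model_args: Sequence[Any]) -> Sequence[Any]:
        return tuple(
            torch.view_as_real(arg)
            if isinstance(arg, torch.Tensor) and torch.is_complex(arg)
            else arg
            for arg in model_args
        )
```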
Done
    else input
    for input in pytorch_inputs
]
ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
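The same idea applied to the input feed for ONNX Runtime, as a hedged sketch; `build_ort_inputs` is an illustrative helper mirroring the excerpt above, not the exact test utility:

```python
import torch

def build_ort_inputs(input_names, pytorch_inputs):
    """Build the name -> numpy feed dict for ONNX Runtime, first viewing any
    complex tensor as its real representation (ORT has no complex dtype)."""
    converted = [
        torch.view_as_real(t)
        if isinstance(t, torch.Tensor) and torch.is_complex(t)
        else t
        for t in pytorch_inputs
    ]
    return {k: v.cpu().numpy() for k, v in zip(input_names, converted)}
```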
ditto: input adapter step
Done
LGTM w/ comments. Please also resolve others' comments before landing.
Are we handling any complex constants that may be inside the graph? Maybe that's an unlikely case that we don't need to care about yet?
It's not covered in this PR. We can add it when we have a model or use case.
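For context, a toy example of the kind of graph-internal complex constant being asked about, which this PR does not rewrite; the module and tracing call here are purely illustrative:

```python
import torch
import torch.fx

class ModelWithComplexConstant(torch.nn.Module):
    def forward(self, x):
        # The complex scale factor is baked into the traced graph as a
        # constant, so it never passes through the placeholder handling above.
        return x * torch.tensor(1 + 2j, dtype=torch.complex64)

graph_module = torch.fx.symbolic_trace(ModelWithComplexConstant())
print(graph_module.graph)  # contains a get_attr node holding a complex tensor
```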
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#100554. Approved by: https://github.com/BowenBao