[ONNX] add cast operator after reduce to match desired dtype #100700
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100700
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending as of commit 274c2a8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I found a few issues with this one, so I will close it until I fix them and then re-open.
… the data input dtype
Thank you for the PR. I have pointed out an issue with how you accessed the scalar type for a torch._C.Value. Please use the recommended way.
torch/onnx/symbolic_opset13.py
Outdated
result = symbolic(g, self)
if dtype_onnx is not None:
    result_dtype_scalar = result.type().scalarType()
    result_dtype_onnx = _type_utils.JitScalarType._from_name(result_dtype_scalar).onnx_type()
Parsing the type from its name is the #1 source of problems (because of the previous comment). Please only use from_value or from_dtype.
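A minimal sketch of what the suggestion amounts to (assuming the JitScalarType.from_value and onnx_type() helpers from torch.onnx._type_utils; result is the value from the diff above):

```python
# Derive the ONNX type directly from the torch._C.Value rather than
# round-tripping through the scalar type's string name.
result_dtype_onnx = _type_utils.JitScalarType.from_value(result).onnx_type()
```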
torch/onnx/symbolic_opset13.py
Outdated
return symbolic(g, self)
result = symbolic(g, self)
if dtype_onnx is not None:
    result_dtype_scalar = result.type().scalarType()
This can cause segfault/assert crashes when the model is scripted with torch.jit.script. To prevent this, exclusively use JitScalarType's public APIs to extract type information from torch._C.Value nodes in a safe way.
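A minimal sketch of the safe pattern (assuming JitScalarType.from_value accepts a default and exposes an UNDEFINED member, as in recent torch.onnx._type_utils):

```python
# Under torch.jit.script the value's type annotation may be unset, so calling
# result.type().scalarType() can crash; fall back to a default instead.
scalar_type = _type_utils.JitScalarType.from_value(
    result, _type_utils.JitScalarType.UNDEFINED
)
```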
Thank you, I will make that change.
Sorry for the lint issue. I used flake8 to check for lint issues locally, but I just realized the CI uses lintrunner, so I'll make that change.
I always do
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR conditionally inserts a cast operator after a reduction operation so that the output dtype in the exported ONNX model matches the specified dtype. The code changes affect opset 9 and opset 13.
I understand there is an automatic upcast to int64 before the reduction, most likely to prevent overflow, so I left that alone and only conditionally cast back to the desired dtype.
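Roughly, the change behaves like the sketch below (a simplification based on the diff hunks above with the reviewer's from_value suggestion applied; not the exact merged code):

```python
result = symbolic(g, self)  # emit the ONNX Reduce* node as before
if dtype_onnx is not None:
    # Cast back only when the reduction produced a different dtype than requested.
    result_dtype_onnx = _type_utils.JitScalarType.from_value(result).onnx_type()
    if result_dtype_onnx != dtype_onnx:
        result = g.op("Cast", result, to_i=dtype_onnx)
return result
```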
Test int32
Test int64
Fixes #100097
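For context, a hypothetical repro of the dtype mismatch tracked in #100097 (the module, shapes, and file name are illustrative, not the PR's actual test code):

```python
import torch


class SumInt32(torch.nn.Module):
    def forward(self, x):
        # ONNX upcasts integer inputs to int64 before ReduceSum; without the
        # added Cast, the exported graph's output dtype would not be int32.
        return torch.sum(x, dtype=torch.int32)


torch.onnx.export(
    SumInt32(),
    (torch.ones(3, 4, dtype=torch.int32),),
    "sum_int32.onnx",
    opset_version=13,
)
```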