Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] add cast operator after reduce to match desired dtype #100700

Closed
wants to merge 7 commits into from

Conversation

TaiPhamD
Copy link
Contributor

@TaiPhamD TaiPhamD commented May 5, 2023

This PR conditionally inserts a cast operator after a reduction operation to match the specified dtype in the exported ONNX model. The code changes affect opset9, and opset13.

I understand there's an automatic upcast to int64 before reduction most likely to prevent overflow so I left that alone and only conditionally add casting back to desired dtype.

Test int32

import torch
import onnx
a = torch.tensor([10, 20, 30, 80], dtype=torch.int32)
def test():
    class SumInt32(torch.nn.Module):
        def forward(self, a):
            return torch.sum(a, dtype=torch.int32)

    sumi = SumInt32().eval()
    assert sumi(a).dtype == torch.int32
    print("Torch model output type matches input type")

    torch.onnx.export(sumi, (a), "/tmp/sumi_int32.onnx", opset_version=12)
    model = onnx.load("/tmp/sumi_int32.onnx")

    assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT32
    print("ONNX model output type matches input type")
test()

sumi_int32 onnx

Test int64

import onnx
import torch

a = torch.tensor([10, 20, 30, 80], dtype=torch.int64)


def test():
    class SumInt64(torch.nn.Module):
        def forward(self, a):
            return torch.sum(a, dtype=torch.int64)

    sumi = SumInt64().eval()
    assert sumi(a).dtype == torch.int64
    print("Torch model output type matches input type")
    torch.onnx.export(sumi, (a), "/tmp/sumi_int64.onnx", opset_version=12)
    model = onnx.load("/tmp/sumi_int64.onnx")
    assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT64
    print("ONNX model output type matches input type")


test()

sum_int64 onnx

Fixes #100097

@TaiPhamD TaiPhamD requested review from BowenBao and abock as code owners May 5, 2023 09:24
@pytorch-bot
Copy link

pytorch-bot bot commented May 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100700

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 274c2a8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented May 5, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label May 5, 2023
@TaiPhamD TaiPhamD changed the title add cast operator after sum to match desired dtype [ONNX] add cast operator after sum to match desired dtype May 5, 2023
@TaiPhamD TaiPhamD changed the title [ONNX] add cast operator after sum to match desired dtype [ONNX] add cast operator after reduce to match desired dtype May 5, 2023
@TaiPhamD
Copy link
Contributor Author

TaiPhamD commented May 5, 2023

I found a few issues with this one so will close until i fix it then will re-open.

@TaiPhamD TaiPhamD closed this May 5, 2023
@TaiPhamD TaiPhamD reopened this May 5, 2023
@BowenBao BowenBao added module: onnx Related to torch.onnx topic: bug fixes topic category labels May 5, 2023
Copy link
Collaborator

@thiagocrepaldi thiagocrepaldi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR. I have pointed out an issue in which how you accessed the scalar tpe for a torch._C.Value. Please use the recommended way

result = symbolic(g, self)
if dtype_onnx is not None:
result_dtype_scalar = result.type().scalarType()
result_dtype_onnx = _type_utils.JitScalarType._from_name(result_dtype_scalar).onnx_type()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parsing type from name is the #1 source of problems (because of the previous comment). Please only use from_value or from_dtype.

return symbolic(g, self)
result = symbolic(g, self)
if dtype_onnx is not None:
result_dtype_scalar = result.type().scalarType()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can cause segfault/assert crashes when model is torch.jit.script.
To prevent this, use exclusively JitScalarType's public APIs to extract type information from torch._C.Value nodes in a safe way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you I will make that change.

@TaiPhamD TaiPhamD requested a review from thiagocrepaldi May 5, 2023 18:03
@TaiPhamD
Copy link
Contributor Author

TaiPhamD commented May 5, 2023

sorry for the lint issue. I used flake8 to test for lint issues locally but just realized the CI uses lintrunner so I'll make that change.

@thiagocrepaldi
Copy link
Collaborator

sorry for the lint issue. I used flake8 to test for lint issues locally but just realized the CI uses lintrunner so I'll make that change.

I always do make setup_lint, lintrunner init and lintrunner -a before pushing a PR to make sure linting is fine

@BowenBao
Copy link
Collaborator

BowenBao commented May 5, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 5, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: onnx Related to torch.onnx open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: bug fixes topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[onnx export] torch.sum unexpected return type
5 participants