Fix logaddexp for ONNX export (#1158)
csukuangfj authored Jul 2, 2023
1 parent 98d8946 commit c3e23ec
Showing 1 changed file: egs/librispeech/ASR/zipformer/scaling.py (14 additions, 2 deletions)
@@ -33,12 +33,24 @@
 # The following function is to solve the above error when exporting
 # models to ONNX via torch.jit.trace()
 def logaddexp(x: Tensor, y: Tensor) -> Tensor:
-    if not torch.jit.is_tracing():
+    # Caution(fangjun): Put torch.jit.is_scripting() before
+    # torch.onnx.is_in_onnx_export();
+    # otherwise, it will cause errors for torch.jit.script().
+    #
+    # torch.logaddexp() works for both torch.jit.script() and
+    # torch.jit.trace() but it causes errors for ONNX export.
+    #
+    if torch.jit.is_scripting():
+        # Note: We cannot use torch.jit.is_tracing() here as it also
+        # matches torch.onnx.export().
         return torch.logaddexp(x, y)
-    else:
+    elif torch.onnx.is_in_onnx_export():
         max_value = torch.max(x, y)
         diff = torch.abs(x - y)
         return max_value + torch.log1p(torch.exp(-diff))
+    else:
+        # for torch.jit.trace()
+        return torch.logaddexp(x, y)
 
 class PiecewiseLinear(object):
     """
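As an aside, the ONNX branch relies on the identity logaddexp(x, y) = max(x, y) + log1p(exp(-|x - y|)), which only uses operators the ONNX exporter can handle. A minimal sketch to check that decomposition numerically against torch.logaddexp (the script and the name logaddexp_decomposed are illustrative, not part of the repository):

#!/usr/bin/env python3

import torch


def logaddexp_decomposed(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same decomposition as the elif branch above:
    # log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|))
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))


if __name__ == "__main__":
    x = torch.randn(10000)
    y = torch.randn(10000)
    expected = torch.logaddexp(x, y)
    actual = logaddexp_decomposed(x, y)
    print("max abs difference:", (actual - expected).abs().max().item())
    assert torch.allclose(actual, expected, atol=1e-6)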

5 comments on commit c3e23ec

@joazoa commented on c3e23ec on Jul 3, 2023


I am still getting the following error:

Exporting the operator 'aten::logaddexp' to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

I don't have this problem with the version before this commit: 219bba1

@csukuangfj (Collaborator Author)


Sorry. Will recheck it.

@csukuangfj (Collaborator Author)


Sorry. I cannot reproduce your issue. The latest master including this PR works perfectly for me.


FYI: I have used the following script for testing when creating this PR:

#!/usr/bin/env python3

import torch


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        if torch.jit.is_scripting():
            return torch.logaddexp(x, y)
        elif torch.onnx.is_in_onnx_export():
            max_value = torch.max(x, y)
            diff = torch.abs(x - y)
            return max_value + torch.log1p(torch.exp(-diff))
        else:
            return torch.logaddexp(x, y)


def main():
    f = Foo()
    x = torch.rand(3)
    m1 = torch.jit.script(f)
    print("---m1---")
    print(m1.graph)

    m2 = torch.jit.trace(f, (x, x))
    print("---m2---")
    print(m2.graph)

    torch.onnx.export(f, (x, x), "a.onnx")


if __name__ == "__main__":
    main()

The output is given below:

---m1---
graph(%self : __torch__.Foo,
      %x.1 : Tensor,
      %y.1 : Tensor):
  %7 : Tensor = aten::logaddexp(%x.1, %y.1) # ./a.py:12:19
  return (%7)

---m2---
graph(%self : __torch__.___torch_mangle_0.Foo,
      %x : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %8 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::logaddexp(%x, %y) # ./a.py:18:0
  return (%8)

Here is a screenshot of the following command:

netron ./a.onnx

[Screenshot 2023-07-03 at 20:41:28 showing the exported graph of a.onnx in netron]

You can see that torch.jit.script() and torch.jit.trace() indeed use torch.logaddexp() when exporting, while torch.onnx.export() uses our customized implementation.
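One way to double-check this from the exported file itself is to list the op types that appear in a.onnx and confirm that no logaddexp-related node survived; a minimal sketch, assuming the a.onnx produced by the test script above and the onnx Python package:

import onnx

model = onnx.load("a.onnx")
op_types = sorted({node.op_type for node in model.graph.node})
print(op_types)
# Expect decomposed ops such as Max, Sub, Abs, Neg, Exp, Log and Add,
# but nothing named after logaddexp.
assert not any("logaddexp" in op.lower() for op in op_types)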

@joazoa commented on c3e23ec on Jul 3, 2023


Thank you! I am probably doing something wrong.

When I run your test script, I get the same output as you pasted above.

When I add some debug prints, I end up three times in

else:
    return torch.logaddexp(x, y)

and the fourth time I end up in

elif torch.onnx.is_in_onnx_export():

However, when I run the onnx_export.py file, I end up in

else:
    return torch.logaddexp(x, y)

all four times.
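For concreteness, debug prints along the lines described above could be added like this (a sketch of hypothetical instrumentation, not the commenter's actual code):

import torch
from torch import Tensor


def logaddexp(x: Tensor, y: Tensor) -> Tensor:
    # Print which branch runs on each call, so the branch counts can be
    # observed during torch.jit.script(), torch.jit.trace(), and
    # torch.onnx.export().
    if torch.jit.is_scripting():
        print("branch: torch.jit.is_scripting()")
        return torch.logaddexp(x, y)
    elif torch.onnx.is_in_onnx_export():
        print("branch: torch.onnx.is_in_onnx_export()")
        max_value = torch.max(x, y)
        diff = torch.abs(x - y)
        return max_value + torch.log1p(torch.exp(-diff))
    else:
        print("branch: eager / torch.jit.trace()")
        return torch.logaddexp(x, y)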

@csukuangfj (Collaborator Author)


However when I run the onnx_export.py file:

Oh, I see. You are getting errors when running export-onnx.py, but I was testing with export-onnx-streaming.py.


I just ran export-onnx.py and got the same error as you. Please wait for a moment. Will try to fix it.
