
TF1-Keras Conv2DTranspose Bias Term Not Fused. #1740

Closed
leimao opened this issue Oct 13, 2021 · 9 comments
Assignees
Labels
contribution welcome Community contribution is welcomed keras Issues related to Keras

Comments

@leimao
Contributor

leimao commented Oct 13, 2021

Describe the bug

Conv2DTranspose with a bias term cannot be exported as a single ONNX ConvTranspose operator. It is exported as ConvTranspose + Add instead.

Urgency

We hope this can be fixed as soon as possible.

System information

  • OS Platform and Distribution: Ubuntu 20.04.2 LTS
  • Tensorflow Version: 1.15.5
  • Python version: 3.8.10

To Reproduce

In the NVIDIA TensorFlow 1 Docker container:

docker run -it --rm --gpus all -v $(pwd):/mnt nvcr.io/nvidia/tensorflow:21.08-tf1-py3

Install tf2onnx (tried both the master branch and the latest release, 1.9.2):

pip install git+https://github.com/onnx/tensorflow-onnx

Run the following script to generate an ONNX file for Conv2DTranspose:

import onnx
import tf2onnx
import tensorflow as tf

inputs = tf.keras.Input(shape=(64, 256, 256))
outputs = tf.keras.layers.Conv2DTranspose(
    filters=128,
    kernel_size=(3, 3),
    strides=(1, 1),
    use_bias=True,
    data_format="channels_first",
    name="conv2d_transpose",
)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

onnx_model, _ = tf2onnx.convert.from_keras(
    model=model,
    opset=13,
    large_model=False,
)

onnx.save_model(onnx_model, "conv2d_transpose.onnx")
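
For reference, the unfused result can be confirmed by listing the node types in the exported graph (a quick check using the onnx package, assuming the file path from the script above):

import onnx

model = onnx.load("conv2d_transpose.onnx")
# a fixed converter should emit a single ConvTranspose with three inputs
# (X, W, B); what we observe instead is ConvTranspose followed by an Add
print([(node.op_type, len(node.input)) for node in model.graph.node])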

Screenshots

[Screenshot: exported ONNX graph showing a ConvTranspose node followed by a separate Add node for the bias]

Additional context

N/A

@TomWildenhain-Microsoft
Contributor

@leimao This should be a relatively easy fix to make. I'm not working on tf2onnx currently and it might be a while until someone else gets to look at this, but if you submit a PR within the next 2 weeks I can help review it.

@TomWildenhain-Microsoft TomWildenhain-Microsoft added the contribution welcome Community contribution is welcomed label Oct 18, 2021
@leimao
Contributor Author

leimao commented Oct 18, 2021

> @leimao This should be a relatively easy fix to make. I'm not working on tf2onnx currently and it might be a while until someone else gets to look at this, but if you submit a PR within the next 2 weeks I can help review it.

Thank you @TomWildenhain-Microsoft. I can create a PR if it is a simple fix. However, in the worst case, if the fix turns out to be too complicated and time-consuming, can I expect that Microsoft will have the bandwidth to fix it two weeks from now?

@leimao
Contributor Author

leimao commented Oct 19, 2021

@TomWildenhain-Microsoft It seems that it is not quite possible to fuse the bias term in the rewriter, because Conv2DBackpropInput does not accept a bias term as input.

[Screenshot: the tf2onnx handler code that generates the ONNX ConvTranspose operator]

Since the rewriter did not apply, the ONNX ConvTranspose operator for the target opset was generated by this method instead. That is why the fusion failed.

I think that unless we run some optimization after the ONNX ConvTranspose operator has been generated, it is impossible to fuse the bias term into it.
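
For illustration, a minimal sketch of what such a post-generation optimization could look like, operating purely on the ONNX graph (the function name is hypothetical, and it assumes the bias feeding the Add is a graph initializer):

import onnx

def fuse_convtranspose_bias(model: onnx.ModelProto) -> None:
    # fold a ConvTranspose -> Add (constant bias) pair into one ConvTranspose;
    # a real fix must also reshape the bias initializer to a 1-D (C,) tensor,
    # since ConvTranspose's B input is 1-D while the Add bias is typically
    # broadcast-shaped, e.g. (1, C, 1, 1)
    graph = model.graph
    initializer_names = {init.name for init in graph.initializer}
    for ct in list(graph.node):
        if ct.op_type != "ConvTranspose" or len(ct.input) != 2:
            continue
        for add in list(graph.node):
            if add.op_type != "Add" or ct.output[0] not in add.input:
                continue
            bias = [name for name in add.input if name != ct.output[0]]
            if len(bias) != 1 or bias[0] not in initializer_names:
                continue
            ct.input.append(bias[0])      # bias becomes ConvTranspose's B input
            ct.output[0] = add.output[0]  # take over the Add's output name
            graph.node.remove(add)
            break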

@leimao
Contributor Author

leimao commented Oct 19, 2021

Hi @TomWildenhain-Microsoft, do you have any suggestions?

@TomWildenhain-Microsoft
Contributor

Rewriters are permitted to output both ONNX ops and TF ops. The current implementation uses TF ops, and you are correct that the TF spec prohibits adding a bias as a 3rd input. However, from my reading of that rewriter, it adds the bias as a 3rd input anyway. This is actually permissible since shape inference has already occurred at this point and we can "extend" the tf spec as a sort of hack, so long as the Conv2DBackpropInput handler properly converts the 3 input version.

Given that the behavior you are seeing is that the BiasAdd is still included (not deleted as the rewriter appears to do), I'm inclined to think the rewriter is not matching the pattern. Try setting allow_reorder to true on the BiasAdd pattern. Maybe it is swapping the input order.
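
For concreteness, a rough sketch of that pattern tweak, using OpTypePattern from tf2onnx's graph_matcher as I read it (the exact pattern structure in the rewriter differs; treat the shape below as an assumption):

from tf2onnx.graph_matcher import GraphMatcher, OpTypePattern

# hypothetical pattern: match a BiasAdd/Add whose inputs may arrive in either
# order; allow_reorder tells the matcher to also try permuted input orderings
pattern = OpTypePattern(
    "BiasAdd|Add",
    allow_reorder=True,  # tolerate swapped inputs on the bias-add node
    inputs=[
        OpTypePattern("Conv2DBackpropInput", name="conv"),
        OpTypePattern("*", name="bias"),
    ],
)
matcher = GraphMatcher(pattern)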

@leimao
Contributor Author

leimao commented Oct 19, 2021

> Rewriters are permitted to output both ONNX ops and TF ops. The current implementation uses TF ops, and you are correct that the TF spec prohibits adding a bias as a 3rd input. However, from my reading of that rewriter, it adds the bias as a 3rd input anyway. This is actually permissible since shape inference has already occurred at this point and we can "extend" the tf spec as a sort of hack, so long as the Conv2DBackpropInput handler properly converts the 3 input version.
>
> Given that the behavior you are seeing is that the BiasAdd is still included (not deleted as the rewriter appears to do), I'm inclined to think the rewriter is not matching the pattern. Try setting allow_reorder to true on the BiasAdd pattern. Maybe it is swapping the input order.

@TomWildenhain-Microsoft Thank you very much. The rewriter did not match the pattern, simply because the existing rewriter matches a Conv/Conv2DBackpropInput that has two inputs, whereas in my case the Conv2DBackpropInput has three inputs. Let me experiment further to see if there can be a simple fix.

In the worst case, if there is no simple TF-op rewriter fix, would you suggest creating an ONNX-op rewriter whose pattern matches ONNX ops?

@hwangdeyu
Contributor

Hi @leimao,
I tried to use your script to reproduce the issue, and I got a graph that is more complex than your screenshot shows. Could you provide a Python script that reproduces the issue more easily?

@leimao
Contributor Author

leimao commented Nov 17, 2021

Thanks @hwangdeyu. Please use the Docker container I mentioned above and tf2onnx 1.9.2 to reproduce; without matching the environment, it is hard to reproduce any bug reliably, for tf2onnx or any other software.

@hwangdeyu
Contributor

The result of this issue's script has changed because #1741 added the optimize() logic for TF1 Keras conversion.
Please consider creating a new issue if there are any remaining problems.
