-
Notifications
You must be signed in to change notification settings - Fork 431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed group attribute in convolution op #2090
Conversation
…ape dimensions so it works for different dimensionalities. Signed-off-by: Javier Dehesa <javidcf@gmail.com>
The code LGTM, thanks for your contributions. Could you please add a test (https://github.com/onnx/tensorflow-onnx/blob/main/tests/test_backend.py) to cover this? |
I wrote this test case for this: @check_opset_min_version(11, "Pad")
def test_conv_unknown_kernel_channels(self):
x_shape = [2, 10, 3]
x_val = make_xval(x_shape)
kernel_shape = [4, 3, 5]
kernel_val = make_xval(kernel_shape)
pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64)
def func(x, kernel, pad):
# Make kernel dimensions unknown
kernel = tf.pad(kernel, pad)
conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID')
return tf.identity(conv, name='output')
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val}) Unfortunately, in the process of doing so I found another bug (already present before the fix) with convolution kernels with unknown shape. My fix solves the issue of the wrong value in the op
The error can also be reproduced before my fix by simply adding the following line before the call to x = tf.pad(x, pad) The problem is I think this is actually an onnxruntime issue. This line from onnxruntime seems to suggest the code is not properly checking that the dimensions are known (not -1), or maybe Either way, I'm not sure if a different test should be added which somehow just checks the attribute value with a EDIT: Another possibility is that the |
Skip adding the optional attribute to Conv nodes if any shape value is negative. Signed-off-by: Javier Dehesa <javidcf@gmail.com>
Signed-off-by: Javier Dehesa <javidcf@gmail.com>
Fixed convolution kernel dimension checks Signed-off-by: Javier Dehesa <javidcf@gmail.com>
Apologies, I had deleted this branch by mistake. |
Could you please resolve the conflict in test_backend.py file? |
Signed-off-by: Javier Dehesa <javidcf@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contributions!
Also minor change reading shape dimensions so it works for different dimensionalities.
Fixes #2084