diff --git a/tests/test_backend.py b/tests/test_backend.py index 059214242..0d22c4734 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -6138,6 +6138,20 @@ def func(x): x_val = make_xval([2, 3]) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(11, "Pad") + def test_conv_unknown_kernel_channels(self): + x_shape = [2, 10, 3] + x_val = make_xval(x_shape) + kernel_shape = [4, 3, 5] + kernel_val = make_xval(kernel_shape) + pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64) + def func(x, kernel, pad): + # Make kernel dimensions unknown + kernel = tf.pad(kernel, pad) + conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID') + return tf.identity(conv, name='output') + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val}) + @check_tf_min_version("2.3.0") @check_opset_min_version(16, "ScatterND") @skip_tfjs("not supported in tfjs") diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 1774adc93..e91cd3b13 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -311,8 +311,9 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2): # Get spatial part. kernel_shape = kernel_shape[:spatial] - # Set new value and return it. - node.set_attr("kernel_shape", kernel_shape) + # Set attribute value only if all dimensions are known. + if all(d > 0 for d in kernel_shape): + node.set_attr("kernel_shape", kernel_shape) return kernel_shape @@ -379,11 +380,13 @@ def any_version(cls, opset, ctx, node, **kwargs): data_format = str(node.attr["data_format"].s, encoding="utf8") shape_dim = -1 if data_format == "NHWC": - shape_dim = ctx.get_shape(node.input[0])[3] + shape_dim = ctx.get_shape(node.input[0])[-1] elif data_format == "NCHW": shape_dim = ctx.get_shape(node.input[0])[1] if shape_dim != -1: - groups = int(shape_dim / ctx.get_shape(node.input[1])[2]) + filter_in_channels = ctx.get_shape(node.input[1])[-2] + if filter_in_channels != -1: + groups = shape_dim // filter_in_channels node.set_attr("group", groups) @@ -649,7 +652,8 @@ def version_1(cls, ctx, node, **kwargs): raise ValueError("input channel must be positive") k_output_channels = k_input_channels * k_channel_multiplier - node.set_attr("kernel_shape", [k_h, k_w]) + if k_h > 0 and k_w > 0: + node.set_attr("kernel_shape", [k_h, k_w]) strides = conv_dims_attr(node, "strides") dilations = conv_dims_attr(node, "dilations") node.set_attr("group", k_input_channels)