diff --git a/tests/test_backend.py b/tests/test_backend.py index a3cccf9dc..80829690b 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -392,6 +392,77 @@ def test_conv2d_6(self): kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape) self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05) + def test_conv3d_1(self): + strides = [1, 1, 1, 1, 1] + dilations = [1, 1, 1, 1, 1] + x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32) + w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32) + padding = "VALID" + def func(x): + kernel = tf.constant(w, dtype=tf.float32, name='k') + conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations) + return tf.identity(conv, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05) + + def test_conv3d_2(self): + strides = [1, 2, 3, 1, 1] + dilations = [1, 1, 1, 1, 1] + x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32) + w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32) + padding = "VALID" + def func(x): + kernel = tf.constant(w, dtype=tf.float32, name='k') + conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations) + return tf.identity(conv, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05) + + def test_conv3d_3(self): + strides = [1, 2, 3, 1, 1] + dilations = [1, 1, 1, 1, 1] + x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32) + w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32) + padding = "SAME" + def func(x): + kernel = tf.constant(w, dtype=tf.float32, name='k') + conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations) + return tf.identity(conv, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05) + + def test_avgpool3d(self): + strides = [1, 1, 1, 1, 1] + ksize = [1, 2, 2, 3, 1] + x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32) + padding = "VALID" + + def func(x): + mp = tf.nn.avg_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC") + return tf.identity(mp, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + + def test_maxpool3d(self): + strides = [1, 1, 1, 1, 1] + ksize = [1, 2, 2, 3, 1] + x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32) + padding = "VALID" + + def func(x): + mp = tf.nn.max_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC") + return tf.identity(mp, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + + @check_tf_min_version("1.14", "tf.nn.avg_pool2d doesn't exist before tf 1.14") + def test_avgpool2d(self): + strides = [1, 1, 1, 1] + ksize = [1, 2, 3, 1] + x_val = make_xval([2, 10, 12, 3]) + padding = "VALID" + + def func(x): + mp = tf.nn.avg_pool2d(x, ksize, strides, padding=padding, data_format="NHWC") + return tf.identity(mp, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + + @check_tf_min_version("1.7", "tf only support dilation is 1 for now") def test_conv2d_7(self): x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384] diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index f385e7a34..d3dec94d9 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -254,6 +254,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2): else: raise ValueError("invalid padding value: {}".format(padding)) +def parse_dims_attr(node, dims, spatial): + if is_channels_last(node): + # We have (N, ..., C) or (...). + if len(dims) != spatial: + dims = dims[1:-1] + else: + # We have (N, C, ...). + dims = dims[2:] + return dims def conv_dims_attr(node, name, new_name=None, spatial=2): # Fetch attribute. @@ -266,13 +275,7 @@ def conv_dims_attr(node, name, new_name=None, spatial=2): # Get spatial part. dims = dims.ints - if is_channels_last(node): - # We have (N, ..., C) or (...). - if len(dims) != spatial: - dims = dims[1:-1] - else: - # We have (N, C, ...). - dims = dims[2:] + dims = parse_dims_attr(node, dims, spatial) # Set new value and return it. node.set_attr(new_name, dims) @@ -475,7 +478,7 @@ def version_1(cls, ctx, node, **kwargs): @tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool") -@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool") +@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool") class PoolOp: @classmethod def version_1(cls, ctx, node, **kwargs): @@ -497,6 +500,11 @@ def _convert(cls, ctx, node, **kwargs): # @AttrType.INTS strides) # above seems wrong - input[1] is ksize, input[2] is strides # stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW + if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]: + spatial = 3 + else: + spatial = 2 + if len(node.input) < 3: kernel_shape_tf = node.get_attr("ksize").ints strides_tf = node.get_attr("strides").ints @@ -506,17 +514,14 @@ def _convert(cls, ctx, node, **kwargs): ctx.remove_input(node, node.input[2]) ctx.remove_input(node, node.input[1]) - if node.is_nhwc(): - kernel_shape_hw = kernel_shape_tf[1:3] - strides_hw = strides_tf[1:3] - else: - kernel_shape_hw = kernel_shape_tf[2:4] - strides_hw = strides_tf[2:4] + kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial) + strides_hw = parse_dims_attr(node, strides_tf, spatial) + node.set_attr("kernel_shape", kernel_shape_hw) node.set_attr("strides", strides_hw) - conv_dims_attr(node, "dilations") - add_padding(ctx, node, kernel_shape_hw, strides_hw) - conv_convert_inputs(ctx, node, with_kernel=False) + dilations = conv_dims_attr(node, "dilations", spatial=spatial) + add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial) + conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial) @tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool") diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index fb5879744..a2ed3e6b8 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -235,6 +235,7 @@ def tensorflow_onnx_mapping(g, ops_mapping): # if there is a onnx_op key we'll map the old type to a new type onnx_op = kwargs.get("onnx_op") if onnx_op: + kwargs["tf_op"] = op node.type = onnx_op body_graphs = node.get_body_graphs() if body_graphs: