Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented MaxPool3D and AvgPool3D #1020

Merged
merged 1 commit into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,77 @@ def test_conv2d_6(self):
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)

def test_conv3d_1(self):
strides = [1, 1, 1, 1, 1]
dilations = [1, 1, 1, 1, 1]
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
padding = "VALID"
def func(x):
kernel = tf.constant(w, dtype=tf.float32, name='k')
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
return tf.identity(conv, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

def test_conv3d_2(self):
strides = [1, 2, 3, 1, 1]
dilations = [1, 1, 1, 1, 1]
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
padding = "VALID"
def func(x):
kernel = tf.constant(w, dtype=tf.float32, name='k')
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
return tf.identity(conv, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

def test_conv3d_3(self):
strides = [1, 2, 3, 1, 1]
dilations = [1, 1, 1, 1, 1]
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
padding = "SAME"
def func(x):
kernel = tf.constant(w, dtype=tf.float32, name='k')
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
return tf.identity(conv, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

def test_avgpool3d(self):
strides = [1, 1, 1, 1, 1]
ksize = [1, 2, 2, 3, 1]
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
padding = "VALID"

def func(x):
mp = tf.nn.avg_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
return tf.identity(mp, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

def test_maxpool3d(self):
strides = [1, 1, 1, 1, 1]
ksize = [1, 2, 2, 3, 1]
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
padding = "VALID"

def func(x):
mp = tf.nn.max_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
return tf.identity(mp, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.14", "tf.nn.avg_pool2d doesn't exist before tf 1.14")
def test_avgpool2d(self):
strides = [1, 1, 1, 1]
ksize = [1, 2, 3, 1]
x_val = make_xval([2, 10, 12, 3])
padding = "VALID"

def func(x):
mp = tf.nn.avg_pool2d(x, ksize, strides, padding=padding, data_format="NHWC")
return tf.identity(mp, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})


@check_tf_min_version("1.7", "tf only support dilation is 1 for now")
def test_conv2d_7(self):
x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384]
Expand Down
39 changes: 22 additions & 17 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
else:
raise ValueError("invalid padding value: {}".format(padding))

def parse_dims_attr(node, dims, spatial):
if is_channels_last(node):
# We have (N, ..., C) or (...).
if len(dims) != spatial:
dims = dims[1:-1]
else:
# We have (N, C, ...).
dims = dims[2:]
return dims

def conv_dims_attr(node, name, new_name=None, spatial=2):
# Fetch attribute.
Expand All @@ -266,13 +275,7 @@ def conv_dims_attr(node, name, new_name=None, spatial=2):

# Get spatial part.
dims = dims.ints
if is_channels_last(node):
# We have (N, ..., C) or (...).
if len(dims) != spatial:
dims = dims[1:-1]
else:
# We have (N, C, ...).
dims = dims[2:]
dims = parse_dims_attr(node, dims, spatial)

# Set new value and return it.
node.set_attr(new_name, dims)
Expand Down Expand Up @@ -475,7 +478,7 @@ def version_1(cls, ctx, node, **kwargs):


@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool")
@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool")
class PoolOp:
@classmethod
def version_1(cls, ctx, node, **kwargs):
Expand All @@ -497,6 +500,11 @@ def _convert(cls, ctx, node, **kwargs):
# @AttrType.INTS strides)
# above seems wrong - input[1] is ksize, input[2] is strides
# stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW
if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]:
spatial = 3
else:
spatial = 2

if len(node.input) < 3:
kernel_shape_tf = node.get_attr("ksize").ints
strides_tf = node.get_attr("strides").ints
Expand All @@ -506,17 +514,14 @@ def _convert(cls, ctx, node, **kwargs):
ctx.remove_input(node, node.input[2])
ctx.remove_input(node, node.input[1])

if node.is_nhwc():
kernel_shape_hw = kernel_shape_tf[1:3]
strides_hw = strides_tf[1:3]
else:
kernel_shape_hw = kernel_shape_tf[2:4]
strides_hw = strides_tf[2:4]
kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
strides_hw = parse_dims_attr(node, strides_tf, spatial)

node.set_attr("kernel_shape", kernel_shape_hw)
node.set_attr("strides", strides_hw)
conv_dims_attr(node, "dilations")
add_padding(ctx, node, kernel_shape_hw, strides_hw)
conv_convert_inputs(ctx, node, with_kernel=False)
dilations = conv_dims_attr(node, "dilations", spatial=spatial)
add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial)
conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)


@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
Expand Down
1 change: 1 addition & 0 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def tensorflow_onnx_mapping(g, ops_mapping):
# if there is a onnx_op key we'll map the old type to a new type
onnx_op = kwargs.get("onnx_op")
if onnx_op:
kwargs["tf_op"] = op
node.type = onnx_op
body_graphs = node.get_body_graphs()
if body_graphs:
Expand Down