Implemented MaxPool3D and AvgPool3D
TomWildenhain-Microsoft committed Jul 24, 2020
1 parent 85bca92 commit f34aa95
Showing 3 changed files with 94 additions and 17 deletions.
71 changes: 71 additions & 0 deletions tests/test_backend.py
@@ -392,6 +392,77 @@ def test_conv2d_6(self):
        kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
        self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)

+    def test_conv3d_1(self):
+        strides = [1, 1, 1, 1, 1]
+        dilations = [1, 1, 1, 1, 1]
+        x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
+        w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
+        padding = "VALID"
+        def func(x):
+            kernel = tf.constant(w, dtype=tf.float32, name='k')
+            conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
+            return tf.identity(conv, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
+
+    def test_conv3d_2(self):
+        strides = [1, 2, 3, 1, 1]
+        dilations = [1, 1, 1, 1, 1]
+        x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
+        w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
+        padding = "VALID"
+        def func(x):
+            kernel = tf.constant(w, dtype=tf.float32, name='k')
+            conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
+            return tf.identity(conv, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
+
+    def test_conv3d_3(self):
+        strides = [1, 2, 3, 1, 1]
+        dilations = [1, 1, 1, 1, 1]
+        x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
+        w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
+        padding = "SAME"
+        def func(x):
+            kernel = tf.constant(w, dtype=tf.float32, name='k')
+            conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
+            return tf.identity(conv, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
+
+    def test_avgpool3d(self):
+        strides = [1, 1, 1, 1, 1]
+        ksize = [1, 2, 2, 3, 1]
+        x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
+        padding = "VALID"
+
+        def func(x):
+            mp = tf.nn.avg_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
+            return tf.identity(mp, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    def test_maxpool3d(self):
+        strides = [1, 1, 1, 1, 1]
+        ksize = [1, 2, 2, 3, 1]
+        x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
+        padding = "VALID"
+
+        def func(x):
+            mp = tf.nn.max_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
+            return tf.identity(mp, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    @check_tf_min_version("1.14", "tf.nn.avg_pool2d doesn't exist before tf 1.14")
+    def test_avgpool2d(self):
+        strides = [1, 1, 1, 1]
+        ksize = [1, 2, 3, 1]
+        x_val = make_xval([2, 10, 12, 3])
+        padding = "VALID"
+
+        def func(x):
+            mp = tf.nn.avg_pool2d(x, ksize, strides, padding=padding, data_format="NHWC")
+            return tf.identity(mp, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+
    @check_tf_min_version("1.7", "tf only support dilation is 1 for now")
    def test_conv2d_7(self):
        x_shape = [1, 35, 35, 288]  # out: [1, 17, 17, 384]
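
For context (not part of the commit), a minimal sketch of the shape arithmetic the new 3-D pooling tests exercise: ksize and strides are given in NDHWC order, and with VALID padding each spatial output dimension is floor((in - k) / s) + 1. Assumes TF 2.x eager execution.

import numpy as np
import tensorflow as tf

x = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)  # N, D, H, W, C
ksize = [1, 2, 2, 3, 1]       # pool over D, H, W only
strides = [1, 1, 1, 1, 1]

y = tf.nn.max_pool3d(x, ksize, strides, padding="VALID", data_format="NDHWC")
print(y.shape)  # (2, 9, 8, 6, 5): D = 10-2+1, H = 9-2+1, W = 8-3+1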
39 changes: 22 additions & 17 deletions tf2onnx/onnx_opset/nn.py
@@ -254,6 +254,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
        else:
            raise ValueError("invalid padding value: {}".format(padding))

+def parse_dims_attr(node, dims, spatial):
+    if is_channels_last(node):
+        # We have (N, ..., C) or (...).
+        if len(dims) != spatial:
+            dims = dims[1:-1]
+    else:
+        # We have (N, C, ...).
+        dims = dims[2:]
+    return dims

def conv_dims_attr(node, name, new_name=None, spatial=2):
    # Fetch attribute.
@@ -266,13 +275,7 @@ def conv_dims_attr(node, name, new_name=None, spatial=2):

    # Get spatial part.
    dims = dims.ints
-    if is_channels_last(node):
-        # We have (N, ..., C) or (...).
-        if len(dims) != spatial:
-            dims = dims[1:-1]
-    else:
-        # We have (N, C, ...).
-        dims = dims[2:]
+    dims = parse_dims_attr(node, dims, spatial)

    # Set new value and return it.
    node.set_attr(new_name, dims)
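
conv_dims_attr now delegates the layout-dependent slicing to the new parse_dims_attr helper above. A minimal standalone sketch of that slicing, not part of the commit (the channels_last flag stands in for tf2onnx's is_channels_last(node)):

def spatial_dims(dims, spatial, channels_last):
    # TF stores ksize/strides/dilations over all axes, including N and C;
    # only the spatial entries are kept for the ONNX attribute.
    if channels_last:      # (N, ..., C) or already spatial-only
        return dims if len(dims) == spatial else dims[1:-1]
    return dims[2:]        # (N, C, ...)

print(spatial_dims([1, 2, 2, 3, 1], 3, channels_last=True))   # [2, 2, 3]
print(spatial_dims([1, 1, 2, 2, 3], 3, channels_last=False))  # [2, 2, 3]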
@@ -475,7 +478,7 @@ def version_1(cls, ctx, node, **kwargs):


@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
-@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool")
+@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool")
class PoolOp:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
@@ -497,6 +500,11 @@ def _convert(cls, ctx, node, **kwargs):
        #               @AttrType.INTS strides)
        # above seems wrong - input[1] is ksize, input[2] is strides
        # stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW
+        if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]:
+            spatial = 3
+        else:
+            spatial = 2
+
        if len(node.input) < 3:
            kernel_shape_tf = node.get_attr("ksize").ints
            strides_tf = node.get_attr("strides").ints
@@ -506,17 +514,14 @@ def _convert(cls, ctx, node, **kwargs):
            ctx.remove_input(node, node.input[2])
            ctx.remove_input(node, node.input[1])

-        if node.is_nhwc():
-            kernel_shape_hw = kernel_shape_tf[1:3]
-            strides_hw = strides_tf[1:3]
-        else:
-            kernel_shape_hw = kernel_shape_tf[2:4]
-            strides_hw = strides_tf[2:4]
+        kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
+        strides_hw = parse_dims_attr(node, strides_tf, spatial)
+
        node.set_attr("kernel_shape", kernel_shape_hw)
        node.set_attr("strides", strides_hw)
-        conv_dims_attr(node, "dilations")
-        add_padding(ctx, node, kernel_shape_hw, strides_hw)
-        conv_convert_inputs(ctx, node, with_kernel=False)
+        dilations = conv_dims_attr(node, "dilations", spatial=spatial)
+        add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial)
+        conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)


@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")
1 change: 1 addition & 0 deletions tf2onnx/tfonnx.py
@@ -235,6 +235,7 @@ def tensorflow_onnx_mapping(g, ops_mapping):
            # if there is a onnx_op key we'll map the old type to a new type
            onnx_op = kwargs.get("onnx_op")
            if onnx_op:
+                kwargs["tf_op"] = op
                node.type = onnx_op
        body_graphs = node.get_body_graphs()
        if body_graphs:
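
The one-line tfonnx.py change stashes the original TF op type before node.type is overwritten with the ONNX name. A toy illustration (not tf2onnx code) of why the shared pooling handler needs this: after the rename, AvgPool and AvgPool3D both read as "AveragePool", so the handler picks the spatial rank from kwargs["tf_op"] instead.

def convert_pool(node_type, **kwargs):
    # node_type is already the ONNX op name; the TF-specific information
    # survives only through the stashed kwargs["tf_op"].
    spatial = 3 if kwargs["tf_op"] in ("AvgPool3D", "MaxPool3D") else 2
    return node_type, spatial

print(convert_pool("AveragePool", tf_op="AvgPool3D"))  # ('AveragePool', 3)
print(convert_pool("MaxPool", tf_op="MaxPool"))        # ('MaxPool', 2)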
