Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] add max_pool3d in relay and TF converter #4551

Merged
merged 2 commits into from
Dec 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ This level enables typical convnet models.
tvm.relay.nn.conv2d_transpose
tvm.relay.nn.dense
tvm.relay.nn.max_pool2d
tvm.relay.nn.max_pool3d
tvm.relay.nn.avg_pool2d
tvm.relay.nn.avg_pool3d
tvm.relay.nn.global_max_pool2d
tvm.relay.nn.global_avg_pool2d
tvm.relay.nn.upsampling
Expand Down Expand Up @@ -246,7 +248,9 @@ Level 2 Definitions
.. autofunction:: tvm.relay.nn.conv2d_transpose
.. autofunction:: tvm.relay.nn.dense
.. autofunction:: tvm.relay.nn.max_pool2d
.. autofunction:: tvm.relay.nn.max_pool3d
.. autofunction:: tvm.relay.nn.avg_pool2d
.. autofunction:: tvm.relay.nn.avg_pool3d
.. autofunction:: tvm.relay.nn.global_max_pool2d
.. autofunction:: tvm.relay.nn.global_avg_pool2d
.. autofunction:: tvm.relay.nn.upsampling
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,10 @@ def __call__(self, args, attrs, type_args):
"nn.dense": op.nn.dense,
"nn.bias_add": op.nn.bias_add,
"nn.max_pool2d": op.nn.max_pool2d,
"nn.max_pool3d": op.nn.max_pool3d,
"nn.global_max_pool2d": op.nn.global_max_pool2d,
"nn.avg_pool2d": op.nn.avg_pool2d,
"nn.avg_pool3d": op.nn.avg_pool3d,
"nn.global_avg_pool2d": op.nn.global_avg_pool2d,
"nn.softmax": op.nn.softmax,
"reshape": op.reshape,
Expand Down
66 changes: 66 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,70 @@ def _impl(inputs, attr, params):
return get_relay_op(name)(*inputs)
return _impl

def _pool3d(name):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

input_shape = attr['_input_shapes'][inputs[0]]

if attr['data_format'] == 'NDHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
elif attr['data_format'] == 'NCDHW':
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4])
attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
else:
msg = 'Value {} of attribute "data_format" of operator Pooling ' \
'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['data_format'] == "NDHWC":
input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
attr['data_format'] = "NCDHW"
yongwww marked this conversation as resolved.
Show resolved Hide resolved
attr['_input_shapes'][inputs[0]] = input_shape
flip_layout = True

attr['padding'] = attr['padding'].decode("utf-8")

if attr['padding'] == 'VALID':
attr['padding'] = [0, 0, 0, 0, 0, 0]
elif attr['padding'] == 'SAME':
stride_d, stride_h, stride_w = attr['strides']
kernel_d, kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NDHWC':
in_d = input_shape[1]
in_h = input_shape[2]
in_w = input_shape[3]
else:
in_d = input_shape[2]
in_h = input_shape[3]
in_w = input_shape[4]
pad_d = _get_pad_pair(in_d, kernel_d, stride_d)
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
else:
msg = 'Value {} in attribute "padding" of operator Pooling is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

if name == "avg_pool":
attr['count_include_pad'] = False
attr['ceil_mode'] = False
out = AttrCvt(
op_name=name,
transforms={
'kernel_shape': 'pool_size',
'data_format': 'layout'},
ignores=['ksize'])(inputs, attr)
if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
return out

return _impl

def _pooling(name):
def _impl(inputs, attr, params):

Expand Down Expand Up @@ -1409,6 +1473,7 @@ def _impl(inputs, attr, params):
'ArgMin' : _argx(_op.argmin, 'argmin'),
'Assert' : _assert(),
'AvgPool' : _pooling('avg_pool'),
'AvgPool3D' : _pool3d('avg_pool3d'),
'BatchMatMul' : _batch_matmul(),
'BatchMatMulV2' : _batch_matmul(),
'BatchNormWithGlobalNormalization' : _batch_norm(),
Expand Down Expand Up @@ -1460,6 +1525,7 @@ def _impl(inputs, attr, params):
'MatMul' : _matmul(),
'Max' : _reduce('max'),
'MaxPool' : _pooling('max_pool'),
'MaxPool3D' : _pool3d('max_pool3d'),
'Maximum' : _elemwise('maximum'),
'Mean' : _mean(),
'Min' : _reduce('min'),
Expand Down
25 changes: 24 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,18 @@ def schedule_max_pool2d(attrs, outs, target):
reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool3d
@reg.register_schedule("nn.max_pool3d")
def schedule_max_pool3d(attrs, outs, target):
"""Schedule definition of max_pool3d"""
layout = attrs.layout
with target:
return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d
@reg.register_schedule("nn.avg_pool2d")
def schedule_avg_pool2d(attrs, outs, target):
Expand All @@ -404,10 +416,21 @@ def schedule_avg_pool2d(attrs, outs, target):
with target:
return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool3d
@reg.register_schedule("nn.avg_pool3d")
def schedule_avg_pool3d(attrs, outs, target):
"""Schedule definition of avg_pool3d"""
layout = attrs.layout
with target:
return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
Expand Down
94 changes: 94 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,51 @@ def max_pool2d(data,
return _make.max_pool2d(data, pool_size, strides, padding,
layout, ceil_mode)

def max_pool3d(data,
pool_size=(1, 1, 1),
strides=(1, 1, 1),
padding=(0, 0, 0),
layout="NCDHW",
ceil_mode=False):
r"""3D maximum pooling operator.

This operator takes data as input and does 3D max value calculation
with in pool_size sized window by striding defined by stride.


In the default case, where the data_layout is `NCDHW`
a data Tensor with shape `(batch_size, channels, depth, height, width)`,
to produce an output Tensor.

The ceil_mode is used to take ceil or floor while computing out shape.
count_include_pad indicates including or excluding padded input values in computation.
This operator accepts data layout specification.

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

strides : tuple of int, optional
The strides of pooling.

padding : tuple of int, optional
The padding for pooling.

layout : str, optional
Layout of the input.

ceil_mode : bool, optional
To enable or disable ceil while pooling.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.max_pool3d(data, pool_size, strides, padding,
layout, ceil_mode)

def avg_pool2d(data,
pool_size=(1, 1),
strides=(1, 1),
Expand Down Expand Up @@ -482,6 +527,55 @@ def avg_pool2d(data,
return _make.avg_pool2d(data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)

def avg_pool3d(data,
pool_size=(1, 1, 1),
strides=(1, 1, 1),
padding=(0, 0, 0),
layout="NCDHW",
ceil_mode=False,
count_include_pad=False):
r"""3D average pooling operator.

This operator takes data as input and does 3D average value calculation
with in pool_size sized window by striding defined by stride


In the default case, where the data_layout is `NCDHW`
a data Tensor with shape `(batch_size, channels, depthm height, width)`,
to produce an output Tensor.

The ceil_mode is used to take ceil or floor while computing out shape.
count_include_pad indicates including or excluding padded input values in computation.
This operator accepts data layout specification.

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

strides : tuple of int, optional
The strides of pooling.

padding : tuple of int, optional
The padding for pooling.

layout : str, optional
Layout of the input.

ceil_mode : bool, optional
To enable or disable ceil while pooling.

count_include_pad : bool, optional
To include padding to compute the average.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.avg_pool3d(data, pool_size, strides, padding,
layout, ceil_mode, count_include_pad)

def max_pool2d_grad(out_grad,
data,
pool_size=(1, 1),
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,16 @@ class AvgPool2DAttrs(Attrs):
"""Attributes used in avg_pool2d operators"""


@register_relay_attr_node
class MaxPool3DAttrs(Attrs):
"""Attributes used in max_pool3d operators"""


@register_relay_attr_node
class AvgPool3DAttrs(Attrs):
"""Attributes used in avg_pool3d operators"""


@register_relay_attr_node
class BitPackAttrs(Attrs):
"""Attributes used in bitpack operator"""
Expand Down
55 changes: 47 additions & 8 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,58 @@ def _test_pooling_iteration(input_shape, **kwargs):
def _test_pooling(input_shape, **kwargs):
_test_pooling_iteration(input_shape, **kwargs)

if is_gpu_available() and (len(input_shape) == 4):
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
if is_gpu_available():
if len(input_shape) == 4:
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)


def test_forward_pooling():
""" Pooling """

# TensorFlow only supports NDHWC for max_pool3d on CPU
for pool_type in ['AVG', 'MAX']:
# NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[2, 2, 2],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[2, 2, 2])

_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[1, 1, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[1, 1, 1])

_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[2, 2, 2],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[2, 2, 2])
yongwww marked this conversation as resolved.
Show resolved Hide resolved

# test cases for max_pool3d & avg_pool3d with layout NCDHW
# TensorFlow pool3d doesn't support NCDHW on cpu
if is_gpu_available():
_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[1, 1, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[1, 1, 1],
data_format='NCDHW')

_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[2, 2, 2],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[2, 2, 2],
data_format='NCDHW')

_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
Expand Down Expand Up @@ -2855,7 +2897,6 @@ def test_forward_add_n():
test_forward_sin()
test_forward_negative()
test_forward_divide()
test_forward_floordiv()
test_forward_abs()
test_forward_softplus()
test_forward_sqrt()
Expand Down Expand Up @@ -2916,5 +2957,3 @@ def test_forward_add_n():
test_forward_where()
test_forward_matmul()
test_forward_batch_matmul()

# TODO missing tests: rank
Loading