diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index fc77869a6261..54163ac3ac61 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -71,7 +71,9 @@ This level enables typical convnet models.
    tvm.relay.nn.conv2d_transpose
    tvm.relay.nn.dense
    tvm.relay.nn.max_pool2d
+   tvm.relay.nn.max_pool3d
    tvm.relay.nn.avg_pool2d
+   tvm.relay.nn.avg_pool3d
    tvm.relay.nn.global_max_pool2d
    tvm.relay.nn.global_avg_pool2d
    tvm.relay.nn.upsampling
@@ -246,7 +248,9 @@ Level 2 Definitions
 .. autofunction:: tvm.relay.nn.conv2d_transpose
 .. autofunction:: tvm.relay.nn.dense
 .. autofunction:: tvm.relay.nn.max_pool2d
+.. autofunction:: tvm.relay.nn.max_pool3d
 .. autofunction:: tvm.relay.nn.avg_pool2d
+.. autofunction:: tvm.relay.nn.avg_pool3d
 .. autofunction:: tvm.relay.nn.global_max_pool2d
 .. autofunction:: tvm.relay.nn.global_avg_pool2d
 .. autofunction:: tvm.relay.nn.upsampling
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 6ec581bdcf7b..45822c56ede2 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -135,8 +135,10 @@ def __call__(self, args, attrs, type_args):
     "nn.dense": op.nn.dense,
     "nn.bias_add": op.nn.bias_add,
     "nn.max_pool2d": op.nn.max_pool2d,
+    "nn.max_pool3d": op.nn.max_pool3d,
     "nn.global_max_pool2d": op.nn.global_max_pool2d,
     "nn.avg_pool2d": op.nn.avg_pool2d,
+    "nn.avg_pool3d": op.nn.avg_pool3d,
     "nn.global_avg_pool2d": op.nn.global_avg_pool2d,
     "nn.softmax": op.nn.softmax,
     "reshape": op.reshape,
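For context, the parser entries above are what let the Relay text format resolve the new operator names. A minimal sketch, assuming the v0.0.4 text-format version and the `relay.fromtext` API current at the time of this patch:

    import tvm.relay as relay

    # A 3D max pool expressed in the Relay text format; the op-name table in
    # _parser.py resolves "nn.max_pool3d" to op.nn.max_pool3d while parsing.
    program = (
        "v0.0.4\n"
        "fn (%x: Tensor[(1, 3, 16, 32, 32), float32]) {\n"
        "    nn.max_pool3d(%x, pool_size=[2, 2, 2], strides=[2, 2, 2])\n"
        "}\n"
    )
    mod = relay.fromtext(program)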
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 460a14699a77..8526f540515d 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -122,6 +122,70 @@ def _impl(inputs, attr, params):
         return get_relay_op(name)(*inputs)
     return _impl
 
+def _pool3d(name):
+    def _impl(inputs, attr, params):
+        attr['data_format'] = attr['data_format'].decode("utf-8")
+        flip_layout = False
+
+        input_shape = attr['_input_shapes'][inputs[0]]
+
+        if attr['data_format'] == 'NDHWC':
+            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
+            attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
+        elif attr['data_format'] == 'NCDHW':
+            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4])
+            attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
+        else:
+            msg = 'Value {} of attribute "data_format" of operator Pooling ' \
+                  'is not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+        if attr['data_format'] == "NDHWC":
+            input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
+            inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
+            attr['data_format'] = "NCDHW"
+            attr['_input_shapes'][inputs[0]] = input_shape
+            flip_layout = True
+
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0, 0, 0, 0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_d, stride_h, stride_w = attr['strides']
+            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
+            if attr['data_format'] == 'NDHWC':
+                in_d = input_shape[1]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
+            else:
+                in_d = input_shape[2]
+                in_h = input_shape[3]
+                in_w = input_shape[4]
+            pad_d = _get_pad_pair(in_d, kernel_d, stride_d)
+            pad_h = _get_pad_pair(in_h, kernel_h, stride_h)
+            pad_w = _get_pad_pair(in_w, kernel_w, stride_w)
+
+            attr['padding'] = [pad_d[0], pad_h[0], pad_w[0], pad_d[1], pad_h[1], pad_w[1]]
+        else:
+            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+
+        if name == "avg_pool":
+            attr['count_include_pad'] = False
+        attr['ceil_mode'] = False
+        out = AttrCvt(
+            op_name=name,
+            transforms={
+                'kernel_shape': 'pool_size',
+                'data_format': 'layout'},
+            ignores=['ksize'])(inputs, attr)
+        if flip_layout:
+            out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
+        return out
+
+    return _impl
+
 def _pooling(name):
     def _impl(inputs, attr, params):
@@ -1409,6 +1473,7 @@ def _impl(inputs, attr, params):
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
     'Assert'                            : _assert(),
     'AvgPool'                           : _pooling('avg_pool'),
+    'AvgPool3D'                         : _pool3d('avg_pool3d'),
     'BatchMatMul'                       : _batch_matmul(),
     'BatchMatMulV2'                     : _batch_matmul(),
     'BatchNormWithGlobalNormalization'  : _batch_norm(),
@@ -1460,6 +1525,7 @@ def _impl(inputs, attr, params):
     'MatMul'                            : _matmul(),
     'Max'                               : _reduce('max'),
     'MaxPool'                           : _pooling('max_pool'),
+    'MaxPool3D'                         : _pool3d('max_pool3d'),
     'Maximum'                           : _elemwise('maximum'),
     'Mean'                              : _mean(),
     'Min'                               : _reduce('min'),
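The converter above reuses the frontend's existing `_get_pad_pair` helper for SAME padding. For reference, the standard TensorFlow SAME-padding arithmetic it relies on can be sketched standalone like this (the helper name `same_pad_pair` is illustrative; the in-tree implementation may differ in style):

    def same_pad_pair(in_size, kernel, stride):
        # TensorFlow 'SAME' padding: pick a total pad so the output has
        # ceil(in_size / stride) positions, split as evenly as possible
        # with the extra element on the far side.
        if in_size % stride == 0:
            pad = max(kernel - stride, 0)
        else:
            pad = max(kernel - (in_size % stride), 0)
        return [pad // 2, pad - pad // 2]

    # e.g. depth 32, kernel 2, stride 2 -> [0, 0]; depth 5, kernel 3, stride 2 -> [1, 1]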
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index e1372ac76480..73bb2e29f365 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -396,6 +396,18 @@ def schedule_max_pool2d(attrs, outs, target):
 reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# max_pool3d
+@reg.register_schedule("nn.max_pool3d")
+def schedule_max_pool3d(attrs, outs, target):
+    """Schedule definition of max_pool3d"""
+    layout = attrs.layout
+    with target:
+        return topi.generic.schedule_pool(outs, layout)
+
+
+reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # avg_pool2d
 @reg.register_schedule("nn.avg_pool2d")
 def schedule_avg_pool2d(attrs, outs, target):
@@ -404,10 +416,21 @@ def schedule_avg_pool2d(attrs, outs, target):
     with target:
         return topi.generic.schedule_pool(outs, layout)
 
-
 reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# avg_pool3d
+@reg.register_schedule("nn.avg_pool3d")
+def schedule_avg_pool3d(attrs, outs, target):
+    """Schedule definition of avg_pool3d"""
+    layout = attrs.layout
+    with target:
+        return topi.generic.schedule_pool(outs, layout)
+
+
+reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # max_pool2d_grad
 @reg.register_schedule("nn.max_pool2d_grad")
 def schedule_max_pool2d_grad(attrs, outs, target):
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index fda5027ee49e..326d72fabc38 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -425,6 +425,52 @@ def max_pool2d(data,
     return _make.max_pool2d(data, pool_size, strides, padding,
                             layout, ceil_mode)
 
+def max_pool3d(data,
+               pool_size=(1, 1, 1),
+               strides=(1, 1, 1),
+               padding=(0, 0, 0),
+               layout="NCDHW",
+               ceil_mode=False):
+    r"""3D maximum pooling operator.
+
+    This operator takes data as input and does 3D max value calculation
+    within a pool_size sized window, strided according to strides.
+
+    In the default case, where the data_layout is `NCDHW`,
+    a data Tensor with shape `(batch_size, channels, depth, height, width)`
+    is pooled to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing the output shape.
+    This operator accepts a data layout specification.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    pool_size : tuple of int, optional
+        The size of the pooling window.
+
+    strides : tuple of int, optional
+        The strides of pooling.
+
+    padding : tuple of int, optional
+        The padding for pooling.
+
+    layout : str, optional
+        Layout of the input.
+
+    ceil_mode : bool, optional
+        To enable or disable ceil while pooling.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.max_pool3d(data, pool_size, strides, padding,
+                            layout, ceil_mode)
+
 def avg_pool2d(data,
                pool_size=(1, 1),
                strides=(1, 1),
@@ -482,6 +528,57 @@ def avg_pool2d(data,
     return _make.avg_pool2d(data, pool_size, strides, padding,
                             layout, ceil_mode, count_include_pad)
 
+def avg_pool3d(data,
+               pool_size=(1, 1, 1),
+               strides=(1, 1, 1),
+               padding=(0, 0, 0),
+               layout="NCDHW",
+               ceil_mode=False,
+               count_include_pad=False):
+    r"""3D average pooling operator.
+
+    This operator takes data as input and does 3D average value calculation
+    within a pool_size sized window, strided according to strides.
+
+    In the default case, where the data_layout is `NCDHW`,
+    a data Tensor with shape `(batch_size, channels, depth, height, width)`
+    is pooled to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing the output shape.
+    count_include_pad indicates including or excluding padded input values in computation.
+    This operator accepts a data layout specification.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    pool_size : tuple of int, optional
+        The size of the pooling window.
+
+    strides : tuple of int, optional
+        The strides of pooling.
+
+    padding : tuple of int, optional
+        The padding for pooling.
+
+    layout : str, optional
+        Layout of the input.
+
+    ceil_mode : bool, optional
+        To enable or disable ceil while pooling.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+ """ + return _make.avg_pool3d(data, pool_size, strides, padding, + layout, ceil_mode, count_include_pad) + def max_pool2d_grad(out_grad, data, pool_size=(1, 1), diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 35b2c053f8cf..e9ddca9cde9c 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -271,6 +271,16 @@ class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" +@register_relay_attr_node +class MaxPool3DAttrs(Attrs): + """Attributes used in max_pool3d operators""" + + +@register_relay_attr_node +class AvgPool3DAttrs(Attrs): + """Attributes used in avg_pool3d operators""" + + @register_relay_attr_node class BitPackAttrs(Attrs): """Attributes used in bitpack operator""" diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 82de233f7b7e..7163eead8435 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -237,16 +237,58 @@ def _test_pooling_iteration(input_shape, **kwargs): def _test_pooling(input_shape, **kwargs): _test_pooling_iteration(input_shape, **kwargs) - if is_gpu_available() and (len(input_shape) == 4): - input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - kwargs['data_format'] = 'NCHW' - _test_pooling_iteration(input_shape, **kwargs) + if is_gpu_available(): + if len(input_shape) == 4: + input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + kwargs['data_format'] = 'NCHW' + _test_pooling_iteration(input_shape, **kwargs) def test_forward_pooling(): """ Pooling """ - + # TensorFlow only supports NDHWC for max_pool3d on CPU for pool_type in ['AVG', 'MAX']: + # NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow + _test_pooling(input_shape=[1, 3, 32, 32, 32], + window_shape=[2, 2, 2], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[1, 1, 1], + strides=[2, 2, 2]) + + _test_pooling(input_shape=[1, 3, 32, 32, 32], + window_shape=[1, 1, 1], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1, 1], + strides=[1, 1, 1]) + + _test_pooling(input_shape=[1, 3, 32, 32, 32], + window_shape=[2, 2, 2], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1, 1], + strides=[2, 2, 2]) + + # test cases for max_pool3d & avg_pool3d with layout NCDHW + # TensorFlow pool3d doesn't support NCDHW on cpu + if is_gpu_available(): + _test_pooling(input_shape=[1, 3, 32, 32, 32], + window_shape=[1, 1, 1], + padding='SAME', + pooling_type=pool_type, + dilation_rate=[1, 1, 1], + strides=[1, 1, 1], + data_format='NCDHW') + + _test_pooling(input_shape=[1, 3, 32, 32, 32], + window_shape=[2, 2, 2], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[1, 1, 1], + strides=[2, 2, 2], + data_format='NCDHW') + _test_pooling(input_shape=[2, 9, 10, 2], window_shape=[1, 1], padding='SAME', @@ -2855,7 +2897,6 @@ def test_forward_add_n(): test_forward_sin() test_forward_negative() test_forward_divide() - test_forward_floordiv() test_forward_abs() test_forward_softplus() test_forward_sqrt() @@ -2916,5 +2957,3 @@ def test_forward_add_n(): test_forward_where() test_forward_matmul() test_forward_batch_matmul() - - # TODO missing tests: rank diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 9257ef22c42c..722c31f0537c 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -471,7 +471,7 @@ def _test_pool2d(opfunc, reffunc): y = opfunc(x, pool_size=(2, 2), 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 9257ef22c42c..722c31f0537c 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -471,7 +471,7 @@ def _test_pool2d(opfunc, reffunc):
     y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
     func = relay.Function([x], y)
     data = np.random.uniform(size=dshape).astype(dtype)
-    ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5))
+    ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5))
     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
         op_res1 = intrp1.evaluate(func)(data)
@@ -532,6 +532,34 @@ def test_pool2d():
     _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
 
 
+def test_pool3d():
+
+    def _test_pool3d(opfunc):
+        n, c, d, h, w = tvm.var("n"), 10, 5, 224, 224
+        x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
+        y = opfunc(x, pool_size=(1, 1, 1))
+        assert "pool_size=" in y.astext()
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((n, 10, 5, 224, 224), "float32")
+        # test execution
+        dtype = "float32"
+        dshape = (1, 3, 32, 32, 32)
+        x = relay.var("x", shape=dshape)
+        pool_type = 'max' if 'max' in str(opfunc) else 'avg'
+        y = opfunc(x, pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 0, 0, 0, 0, 0))
+        func = relay.Function([x], y)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        ref_res = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
+                                                   (0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16), pool_type, False)
+        for target, ctx in ctx_list():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    _test_pool3d(relay.nn.max_pool3d)
+    _test_pool3d(relay.nn.avg_pool3d)
+
+
 def test_avg_pool2d_no_count_pad():
     kh, kw = (4, 4)
     sh, sw = (2, 2)
@@ -900,6 +928,7 @@ def test_bitpack_infer_type():
 if __name__ == "__main__":
     test_pool2d()
+    test_pool3d()
     test_avg_pool2d_no_count_pad()
     test_lrn()
     test_l2_normalize()
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 43e9f19d880b..f9acb73bf0cc 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -43,5 +43,6 @@
 from .batch_matmul import batch_matmul
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
+from .pool3d_python import pool3d_ncdhw_python
 from .pool_grad_python import pool_grad_nchw
 from .one_hot import one_hot
diff --git a/topi/python/topi/testing/pool3d_python.py b/topi/python/topi/testing/pool3d_python.py
new file mode 100644
index 000000000000..32513163d068
--- /dev/null
+++ b/topi/python/topi/testing/pool3d_python.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, unused-variable
+"""max_pool3d and avg_pool3d in python"""
+import math
+import numpy as np
+
+def pool3d_ncdhw_python(np_data, kernel,
+                        strides, padding,
+                        out_shape, pool_type,
+                        count_include_pad=True,
+                        ceil_mode=False, dtype="float32"):
+    """Baseline implementation of max_pool3d and avg_pool3d with NCDHW layout."""
+    in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape
+    k_d, k_h, k_w = kernel
+    s_d, s_h, s_w = strides
+    pf, pt, pl, pk, pb, pr = padding
+
+    if ceil_mode:
+        assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
+        assert out_shape[3] == int(math.ceil(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
+        assert out_shape[4] == int(math.ceil(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
+    else:
+        assert out_shape[2] == int(math.floor(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
+        assert out_shape[3] == int(math.floor(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
+        assert out_shape[4] == int(math.floor(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
+
+    pad_np = np.zeros(shape=(in_n, in_c,
+                             in_d + pf + pk,
+                             in_h + pt + pb,
+                             in_w + pl + pr)).astype(dtype)
+    no_zero = (range(in_n),
+               range(in_c),
+               (range(pf, in_d + pf)),
+               (range(pt, in_h + pt)),
+               (range(pl, in_w + pl)))
+    pad_np[np.ix_(*no_zero)] = np_data
+    ret_np = np.zeros(shape=out_shape).astype(dtype)
+
+    if pool_type == 'avg':
+        for k in range(out_shape[2]):
+            for i in range(out_shape[3]):
+                for j in range(out_shape[4]):
+                    if count_include_pad:
+                        ret_np[:, :, k, i, j] = \
+                            np.mean(pad_np[:, :, k * s_d: k * s_d + k_d,
+                                           i * s_h: i * s_h + k_h,
+                                           j * s_w: j * s_w + k_w], axis=(2, 3, 4))
+                    else:
+                        # count only the non-padded elements; this assumes the
+                        # input values are strictly positive, as generated by the tests
+                        pad_count = np.sum(pad_np[:, :,
+                                                  k * s_d: k * s_d + k_d,
+                                                  i * s_h: i * s_h + k_h,
+                                                  j * s_w: j * s_w + k_w] > 0, axis=(2, 3, 4))
+                        ret_np[:, :, k, i, j] = np.sum(pad_np[:, :,
+                                                              k * s_d: k * s_d + k_d,
+                                                              i * s_h: i * s_h + k_h,
+                                                              j * s_w: j * s_w + k_w],
+                                                       axis=(2, 3, 4)) / np.maximum(pad_count, 1)
+    elif pool_type == 'max':
+        for k in range(out_shape[2]):
+            for i in range(out_shape[3]):
+                for j in range(out_shape[4]):
+                    ret_np[:, :, k, i, j] = np.max(
+                        pad_np[:, :, k * s_d: k * s_d + k_d,
+                               i * s_h: i * s_h + k_h,
+                               j * s_w: j * s_w + k_w], axis=(2, 3, 4))
+    else:
+        raise ValueError("pool type {} is not supported".format(pool_type))
+
+    # clamp at zero to match the topi tests, which apply relu to the pool output
+    ret_np = np.maximum(ret_np, 0.0)
+    return ret_np
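A quick sanity check of the reference above, comparing a non-overlapping 2x2x2 max pool against a blocked-reshape computation (random positive inputs make the final clamp at zero a no-op):

    import numpy as np
    import topi.testing

    data = np.random.uniform(size=(1, 1, 4, 4, 4)).astype("float32")
    out = topi.testing.pool3d_ncdhw_python(data, kernel=(2, 2, 2), strides=(2, 2, 2),
                                           padding=(0, 0, 0, 0, 0, 0),
                                           out_shape=(1, 1, 2, 2, 2), pool_type='max')
    # Each output element is the max over one non-overlapping 2x2x2 block.
    expected = data.reshape(1, 1, 2, 2, 2, 2, 2, 2).max(axis=(3, 5, 7))
    np.testing.assert_allclose(out, expected)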
"""Test code for pooling""" +import math import numpy as np import tvm import topi import topi.testing -import math from topi.util import get_const_tuple - from common import get_all_backend def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): @@ -264,57 +263,25 @@ def test_adaptive_pool(): verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max") verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg") -def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): - iz = iw = ih - kz = kw = kh - sz = sw = sh - pf, pt, pl, pk, pb, pr = padding - layout = "NCDHW" - A = tvm.placeholder((n, ic, iz, ih, iw), name='A') - B = topi.nn.pool3d(A, kernel=[kz, kh, kw], stride=[sz, sh, sw], padding=padding, +def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, + ceil_mode, count_include_pad=True, layout='NCDHW'): + id = iw = ih + kd = kw = kh + sd = sw = sh + input_shape = (n, ic, id, ih, iw) + kernel = [kd, kh, kw] + stride = [sd, sh, sw] + A = tvm.placeholder(input_shape, name='A') + B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding, pool_type=pool_type, ceil_mode=ceil_mode, - layout="NCDHW", count_include_pad=count_include_pad) + layout=layout, count_include_pad=count_include_pad) B = topi.nn.relu(B) dtype = A.dtype + output_shape = [int(i) for i in B.shape] - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kz + pf + pk) / sz) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kh + pt + pb) / sh) + 1) - assert bshape[4] == int(math.ceil(float(ashape[4] - kw + pl + pr) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kz + pf + pk) / sz) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kh + pt + pb) / sh) + 1) - assert bshape[4] == int(math.floor(float(ashape[4] - kw + pl + pr) / sw) + 1) - - a_np = np.random.uniform(low=0.001, size=(n, ic, iz, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, iz+pf+pk, ih+pt+pb, iw+pl+pr)).astype(dtype) - no_zero = (range(n), range(ic), (range(pf, iz+pf)), (range(pt, ih+pt)), (range(pl, iw+pl))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oz, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oz, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for k in range(oz): - for i in range(oh): - for j in range(ow): - if count_include_pad: - b_np[:,:,k,i,j] = np.mean( \ - pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4)) - else: - pad_count = np.sum( \ - pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3,4)) - b_np[:,:,k,i,j] = np.sum(pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], \ - axis=(2,3, 4)) / np.maximum(pad_count, 1) - - elif pool_type =='max': - for k in range(oz): - for i in range(oh): - for j in range(ow): - b_np[:,:,k,i,j] = np.max( \ - pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4)) - b_np = np.maximum(b_np, 0.0) + input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype) + ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding, + output_shape, pool_type, count_include_pad, ceil_mode) def check_device(device): ctx = tvm.context(device, 0) @@ -325,11 +292,11 @@ def check_device(device): with tvm.target.create(device): s = topi.generic.schedule_pool(B, layout) - a = tvm.nd.array(a_np, ctx) + a = tvm.nd.array(input_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], 
        f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)
 
     for device in get_all_backend():
         check_device(device)
@@ -353,7 +320,7 @@ def test_pool3d():
 
 if __name__ == "__main__":
     test_pool()
+    test_pool3d()
     test_pool_grad()
     test_global_pool()
     test_adaptive_pool()
-    test_pool3d()
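Taken together, the patch makes the following Relay-level flow work end to end. A usage sketch on the llvm target, mirroring the tests above (the create_executor call assumes the ctx/target signature current in this TVM version):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 3, 16, 32, 32), dtype="float32")  # NCDHW layout
    y = relay.nn.max_pool3d(x, pool_size=(2, 2, 2), strides=(2, 2, 2),
                            padding=(0, 0, 0))
    func = relay.Function([x], y)

    data = np.random.uniform(size=(1, 3, 16, 32, 32)).astype("float32")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    out = intrp.evaluate(func)(data)  # result shape: (1, 3, 8, 16, 16)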