From 29376622a670ca8ca26e2ce22fc9a09b6c80713e Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Sat, 21 Dec 2019 18:22:47 +0000
Subject: [PATCH] fix comments

---
 python/tvm/relay/frontend/tensorflow.py   |  5 +-
 .../frontend/tensorflow/test_forward.py   | 27 ++++--
 tests/python/relay/test_op_level2.py      | 42 +---------
 topi/python/topi/testing/__init__.py      |  1 +
 topi/python/topi/testing/pool3d_python.py | 82 +++++++++++++++++++
 topi/tests/python/test_topi_pooling.py    | 71 +++++-----------
 6 files changed, 127 insertions(+), 101 deletions(-)
 create mode 100644 topi/python/topi/testing/pool3d_python.py

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 82e1a4f8a525a..8526f540515de 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -140,10 +140,10 @@ def _impl(inputs, attr, params):
                   'is not valid.'
             raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
         if attr['data_format'] == "NDHWC":
-            tmp_shape = attr['_input_shapes'][inputs[0]]
-            input_shape = [tmp_shape[ii] for ii in (0, 4, 1, 2, 3)]
+            input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
             inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
             attr['data_format'] = "NCDHW"
+            attr['_input_shapes'][inputs[0]] = input_shape
             flip_layout = True
 
         attr['padding'] = attr['padding'].decode("utf-8")
@@ -174,7 +174,6 @@ def _impl(inputs, attr, params):
         if name == "avg_pool":
             attr['count_include_pad'] = False
         attr['ceil_mode'] = False
-        attr['data_format'] = 'NCDHW'
         out = AttrCvt(
             op_name=name,
             transforms={
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index bfcd6bda8d4da..7163eead8435e 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -242,17 +242,13 @@ def _test_pooling(input_shape, **kwargs):
         input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
         kwargs['data_format'] = 'NCHW'
         _test_pooling_iteration(input_shape, **kwargs)
-    elif len(input_shape) == 5:
-        input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
-        kwargs['data_format'] = 'NCDHW'
-        _test_pooling_iteration(input_shape, **kwargs)
 
 def test_forward_pooling():
     """ Pooling """
     # TensorFlow only supports NDHWC for max_pool3d on CPU
-    for pool_type in ['MAX', 'AVG']:
-
+    for pool_type in ['AVG', 'MAX']:
+        # NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
         _test_pooling(input_shape=[1, 3, 32, 32, 32],
                       window_shape=[2, 2, 2],
                       padding='VALID',
                       pooling_type=pool_type,
                       dilation_rate=[1, 1, 1],
                       strides=[2, 2, 2])
@@ -274,6 +270,25 @@ def test_forward_pooling():
                       dilation_rate=[1, 1, 1],
                       strides=[2, 2, 2])
 
+        # test cases for max_pool3d & avg_pool3d with layout NCDHW
+        # TensorFlow pool3d doesn't support NCDHW on CPU
+        if is_gpu_available():
+            _test_pooling(input_shape=[1, 3, 32, 32, 32],
+                          window_shape=[1, 1, 1],
+                          padding='SAME',
+                          pooling_type=pool_type,
+                          dilation_rate=[1, 1, 1],
+                          strides=[1, 1, 1],
+                          data_format='NCDHW')
+
+            _test_pooling(input_shape=[1, 3, 32, 32, 32],
+                          window_shape=[2, 2, 2],
+                          padding='VALID',
+                          pooling_type=pool_type,
+                          dilation_rate=[1, 1, 1],
+                          strides=[2, 2, 2],
+                          data_format='NCDHW')
+
         _test_pooling(input_shape=[2, 9, 10, 2],
                       window_shape=[1, 1],
                       padding='SAME',
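
Aside on the layout handling above: the tensorflow.py change converts an NDHWC pool3d to NCDHW by transposing the data with axes (0, 4, 1, 2, 3) and re-registering the permuted shape for the new input expression. A minimal standalone NumPy sketch of that permutation (illustrative only, not part of the patch; the array contents are arbitrary):

    import numpy as np

    # NDHWC input: (batch, depth, height, width, channels)
    x_ndhwc = np.random.uniform(size=(1, 8, 16, 16, 3)).astype("float32")

    # Same axis order the converter passes to _op.transpose(..., axes=(0, 4, 1, 2, 3))
    x_ncdhw = np.transpose(x_ndhwc, axes=(0, 4, 1, 2, 3))

    # The recorded shape must be permuted with the identical index order, which is
    # what the added attr['_input_shapes'][inputs[0]] = input_shape line keeps in sync.
    input_shape = [x_ndhwc.shape[i] for i in (0, 4, 1, 2, 3)]
    assert tuple(input_shape) == x_ncdhw.shape == (1, 3, 8, 16, 16)
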
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 25b073cb69c63..722c31f0537cb 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -534,44 +534,6 @@ def test_pool2d():
 
 
 def test_pool3d():
-    def _test_pool3d_baseline(np_data, in_shape, kernal, strides, padding,
-                              out_shape, pool_type, count_include_pad=True):
-        # default layout is "NCDHW"
-        dtype = "float32"
-        n, ic, id, ih, iw = in_shape
-        kd, kw, kh = kernal
-        sd, sw, sh = strides
-        pf, pt, pl, pk, pb, pr = padding
-
-        pad_np = np.zeros(shape=(n, ic, id + pf + pk, ih + pt + pb, iw + pl + pr)).astype(dtype)
-        no_zero = (range(n), range(ic), (range(pf, id + pf)), (range(pt, ih + pt)), (range(pl, iw + pl)))
-        pad_np[np.ix_(*no_zero)] = np_data
-        ret_np = np.zeros(shape=out_shape).astype(dtype)
-
-        if pool_type == 'avg':
-            for k in range(out_shape[2]):
-                for i in range(out_shape[3]):
-                    for j in range(out_shape[4]):
-                        if count_include_pad:
-                            ret_np[:, :, k, i, j] = \
-                                np.mean(pad_np[:, :, k * sd:k * sd + kd,
-                                        i * sh:i * sh + kh, j * sw:j * sw + kw], axis=(2, 3, 4))
-                        else:
-                            pad_count = np.sum(pad_np[:, :, k * sd:k * sd + kd,
-                                               i * sh:i * sh + kh, j * sw:j * sw + kw] > 0, axis=(2, 3, 4))
-                            ret_np[:, :, k, i, j] = np.sum(pad_np[:, :, k * sd:k * sd + kd,
-                                                           i * sh:i * sh + kh, j * sw:j * sw + kw],
-                                                           axis=(2, 3, 4)) / np.maximum(pad_count, 1)
-        elif pool_type == 'max':
-            for k in range(out_shape[2]):
-                for i in range(out_shape[3]):
-                    for j in range(out_shape[4]):
-                        ret_np[:, :, k, i, j] = np.max(
-                            pad_np[:, :, k * sd:k * sd + kd,
-                            i * sh:i * sh + kh, j * sw:j * sw + kw], axis=(2, 3, 4))
-        ret_np = np.maximum(ret_np, 0.0)
-        return ret_np
-
     def _test_pool3d(opfunc):
         n, c, d, h, w = tvm.var("n"), 10, 5, 224, 224
         x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
@@ -587,8 +549,8 @@ def _test_pool3d(opfunc):
         y = opfunc(x, pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 0, 0, 0, 0, 0))
         func = relay.Function([x], y)
         data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = _test_pool3d_baseline(data, (1, 3, 32, 32, 32), (2, 2, 2), (2, 2, 2),
-                                        (0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16), pool_type, False)
+        ref_res = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
+                                                   (0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16), pool_type, False)
         for target, ctx in ctx_list():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
             op_res1 = intrp1.evaluate(func)(data)
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 43e9f19d880b1..f9acb73bf0cc7 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -43,5 +43,6 @@
 from .batch_matmul import batch_matmul
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
+from .pool3d_python import pool3d_ncdhw_python
 from .pool_grad_python import pool_grad_nchw
 from .one_hot import one_hot
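
A quick usage sketch for the pool3d_ncdhw_python baseline exported above (its full definition is the new file that follows). The input is deterministic so a single window can be checked by hand; this assumes a TVM/topi checkout with this patch applied, and the concrete numbers are illustrative:

    import numpy as np
    import topi.testing

    # 1x1x4x4x4 NCDHW input whose element at (d, h, w) equals d*16 + h*4 + w
    data = np.arange(64, dtype="float32").reshape(1, 1, 4, 4, 4)

    # 2x2x2 window, stride 2, no padding: each spatial extent shrinks 4 -> floor((4 - 2) / 2) + 1 = 2
    out_max = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
                                               (0, 0, 0, 0, 0, 0), (1, 1, 2, 2, 2), 'max')
    assert out_max.shape == (1, 1, 2, 2, 2)
    # the first window covers d, h, w in {0, 1}; its maximum sits at (1, 1, 1) = 16 + 4 + 1
    assert out_max[0, 0, 0, 0, 0] == 21.0

    out_avg = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
                                               (0, 0, 0, 0, 0, 0), (1, 1, 2, 2, 2), 'avg')
    # mean of {0, 1, 4, 5, 16, 17, 20, 21} is 10.5
    assert out_avg[0, 0, 0, 0, 0] == 10.5
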
diff --git a/topi/python/topi/testing/pool3d_python.py b/topi/python/topi/testing/pool3d_python.py
new file mode 100644
index 0000000000000..451a19d7d3ce0
--- /dev/null
+++ b/topi/python/topi/testing/pool3d_python.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, unused-variable
+"""max_pool3d and avg_pool3d in python"""
+import math
+import numpy as np
+
+def pool3d_ncdhw_python(np_data, kernel,
+                        strides, padding,
+                        out_shape, pool_type,
+                        count_include_pad=True,
+                        ceil_mode=False, dtype="float32"):
+    """baseline for max_pool3d and avg_pool3d with NCDHW layout"""
+    in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape
+    k_d, k_h, k_w = kernel
+    s_d, s_h, s_w = strides
+    pf, pt, pl, pk, pb, pr = padding
+
+    if ceil_mode:
+        assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
+        assert out_shape[3] == int(math.ceil(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
+        assert out_shape[4] == int(math.ceil(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
+    else:
+        assert out_shape[2] == int(math.floor(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
+        assert out_shape[3] == int(math.floor(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
+        assert out_shape[4] == int(math.floor(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
+
+    pad_np = np.zeros(shape=(in_n, in_c,
+                             in_d + pf + pk,
+                             in_h + pt + pb,
+                             in_w + pl + pr)).astype(dtype)
+    no_zero = (range(in_n),
+               range(in_c),
+               (range(pf, in_d + pf)),
+               (range(pt, in_h + pt)),
+               (range(pl, in_w + pl)))
+    pad_np[np.ix_(*no_zero)] = np_data
+    ret_np = np.zeros(shape=out_shape).astype(dtype)
+
+    if pool_type == 'avg':
+        for k in range(out_shape[2]):
+            for i in range(out_shape[3]):
+                for j in range(out_shape[4]):
+                    if count_include_pad:
+                        ret_np[:, :, k, i, j] = \
+                            np.mean(pad_np[:, :, k * s_d: k * s_d + k_d,
+                                           i * s_h: i * s_h + k_h,
+                                           j * s_w: j * s_w + k_w], axis=(2, 3, 4))
+                    else:
+                        pad_count = np.sum(pad_np[:, :,
+                                                  k * s_d: k * s_d + k_d,
+                                                  i * s_h: i * s_h + k_h,
+                                                  j * s_w: j * s_w + k_w] > 0, axis=(2, 3, 4))
+                        ret_np[:, :, k, i, j] = np.sum(pad_np[:, :,
+                                                              k * s_d: k * s_d + k_d,
+                                                              i * s_h: i * s_h + k_h,
+                                                              j * s_w: j * s_w + k_w],
+                                                       axis=(2, 3, 4)) / np.maximum(pad_count, 1)
+    elif pool_type == 'max':
+        for k in range(out_shape[2]):
+            for i in range(out_shape[3]):
+                for j in range(out_shape[4]):
+                    ret_np[:, :, k, i, j] = np.max(
+                        pad_np[:, :, k * s_d: k * s_d + k_d,
+                               i * s_h: i * s_h + k_h,
+                               j * s_w: j * s_w + k_w], axis=(2, 3, 4))
+    ret_np = np.maximum(ret_np, 0.0)
+    return ret_np
diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py
index 8a32c18d700cf..8f649de9a24d4 100644
--- a/topi/tests/python/test_topi_pooling.py
+++ b/topi/tests/python/test_topi_pooling.py
@@ -15,13 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test code for pooling"""
+import math
 import numpy as np
 import tvm
 import topi
 import topi.testing
-import math
 from topi.util import get_const_tuple
-
 from common import get_all_backend
 
 def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
@@ -264,57 +263,25 @@ def test_adaptive_pool():
     verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
     verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")
 
-def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
-    iz = iw = ih
-    kz = kw = kh
-    sz = sw = sh
-    pf, pt, pl, pk, pb, pr = padding
-    layout = "NCDHW"
-    A = tvm.placeholder((n, ic, iz, ih, iw), name='A')
-    B = topi.nn.pool3d(A, kernel=[kz, kh, kw], stride=[sz, sh, sw], padding=padding,
+def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
+                  ceil_mode, count_include_pad=True, layout='NCDHW'):
+    id = iw = ih
+    kd = kw = kh
+    sd = sw = sh
+    input_shape = (n, ic, id, ih, iw)
+    kernel = [kd, kh, kw]
+    stride = [sd, sh, sw]
+    A = tvm.placeholder(input_shape, name='A')
+    B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding,
                        pool_type=pool_type, ceil_mode=ceil_mode,
-                       layout="NCDHW", count_include_pad=count_include_pad)
+                       layout=layout, count_include_pad=count_include_pad)
     B = topi.nn.relu(B)
     dtype = A.dtype
+    output_shape = [int(i) for i in B.shape]
 
-    bshape = get_const_tuple(B.shape)
-    ashape = get_const_tuple(A.shape)
-    if ceil_mode:
-        assert bshape[2] == int(math.ceil(float(ashape[2] - kz + pf + pk) / sz) + 1)
-        assert bshape[3] == int(math.ceil(float(ashape[3] - kh + pt + pb) / sh) + 1)
-        assert bshape[4] == int(math.ceil(float(ashape[4] - kw + pl + pr) / sw) + 1)
-    else:
-        assert bshape[2] == int(math.floor(float(ashape[2] - kz + pf + pk) / sz) + 1)
-        assert bshape[3] == int(math.floor(float(ashape[3] - kh + pt + pb) / sh) + 1)
-        assert bshape[4] == int(math.floor(float(ashape[4] - kw + pl + pr) / sw) + 1)
-
-    a_np = np.random.uniform(low=0.001, size=(n, ic, iz, ih, iw)).astype(dtype)
-    pad_np = np.zeros(shape=(n, ic, iz+pf+pk, ih+pt+pb, iw+pl+pr)).astype(dtype)
-    no_zero = (range(n), range(ic), (range(pf, iz+pf)), (range(pt, ih+pt)), (range(pl, iw+pl)))
-    pad_np[np.ix_(*no_zero)] = a_np
-    _, oc, oz, oh, ow = get_const_tuple(B.shape)
-    b_np = np.zeros(shape=(n, oc, oz, oh, ow)).astype(dtype)
-
-    if pool_type == 'avg':
-        for k in range(oz):
-            for i in range(oh):
-                for j in range(ow):
-                    if count_include_pad:
-                        b_np[:,:,k,i,j] = np.mean( \
-                            pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
-                    else:
-                        pad_count = np.sum( \
-                            pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3,4))
-                        b_np[:,:,k,i,j] = np.sum(pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], \
-                            axis=(2,3, 4)) / np.maximum(pad_count, 1)
-
-    elif pool_type =='max':
-        for k in range(oz):
-            for i in range(oh):
-                for j in range(ow):
-                    b_np[:,:,k,i,j] = np.max( \
-                        pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
-    b_np = np.maximum(b_np, 0.0)
+    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
+    ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding,
+                                              output_shape, pool_type, count_include_pad, ceil_mode)
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -325,11 +292,11 @@ def check_device(device):
         with tvm.target.create(device):
             s = topi.generic.schedule_pool(B, layout)
 
-        a = tvm.nd.array(a_np, ctx)
+        a = tvm.nd.array(input_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+        tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)
 
     for device in get_all_backend():
         check_device(device)
@@ -353,7 +320,7 @@ def test_pool3d():
 
 if __name__ == "__main__":
     test_pool()
+    test_pool3d()
     test_pool_grad()
     test_global_pool()
     test_adaptive_pool()
-    test_pool3d()
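
For reference, every padding argument in this patch is the 6-tuple (pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right), so depth is padded by pf + pk, height by pt + pb and width by pl + pr, and the output extents follow the usual pooling size rule that pool3d_ncdhw_python asserts (floor by default, ceil when ceil_mode is set). A standalone check of that arithmetic, using the shapes from the tests above plus one illustrative padded case:

    import math

    def pool_out_dim(in_dim, kernel, stride, pad_a, pad_b, ceil_mode=False):
        # mirrors the shape asserts in pool3d_ncdhw_python
        rounder = math.ceil if ceil_mode else math.floor
        return int(rounder(float(in_dim - kernel + pad_a + pad_b) / stride) + 1)

    # the Relay test case: 32 -> 16 along depth, height and width (2x2x2 kernel, stride 2, no padding)
    assert pool_out_dim(32, 2, 2, 0, 0) == 16
    # illustrative padded case: ceil mode keeps the trailing partial window, floor mode drops it
    assert pool_out_dim(31, 3, 3, 1, 2) == 11
    assert pool_out_dim(31, 3, 3, 1, 2, ceil_mode=True) == 12
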