Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Dec 22, 2019
1 parent c939107 commit 2937662
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 101 deletions.
5 changes: 2 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def _impl(inputs, attr, params):
'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['data_format'] == "NDHWC":
tmp_shape = attr['_input_shapes'][inputs[0]]
input_shape = [tmp_shape[ii] for ii in (0, 4, 1, 2, 3)]
input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
attr['data_format'] = "NCDHW"
attr['_input_shapes'][inputs[0]] = input_shape
flip_layout = True

attr['padding'] = attr['padding'].decode("utf-8")
Expand Down Expand Up @@ -174,7 +174,6 @@ def _impl(inputs, attr, params):
if name == "avg_pool":
attr['count_include_pad'] = False
attr['ceil_mode'] = False
attr['data_format'] = 'NCDHW'
out = AttrCvt(
op_name=name,
transforms={
Expand Down
27 changes: 21 additions & 6 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,13 @@ def _test_pooling(input_shape, **kwargs):
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
elif len(input_shape) == 5:
input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
kwargs['data_format'] = 'NCDHW'
_test_pooling_iteration(input_shape, **kwargs)


def test_forward_pooling():
""" Pooling """
# TensorFlow only supports NDHWC for max_pool3d on CPU
for pool_type in ['MAX', 'AVG']:

for pool_type in ['AVG', 'MAX']:
# NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[2, 2, 2],
padding='VALID',
Expand All @@ -274,6 +270,25 @@ def test_forward_pooling():
dilation_rate=[1, 1, 1],
strides=[2, 2, 2])

# test cases for max_pool3d & avg_pool3d with layout NCDHW
# TensorFlow pool3d doesn't support NCDHW on cpu
if is_gpu_available():
_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[1, 1, 1],
padding='SAME',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[1, 1, 1],
data_format='NCDHW')

_test_pooling(input_shape=[1, 3, 32, 32, 32],
window_shape=[2, 2, 2],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[1, 1, 1],
strides=[2, 2, 2],
data_format='NCDHW')

_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
Expand Down
42 changes: 2 additions & 40 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,44 +534,6 @@ def test_pool2d():

def test_pool3d():

def _test_pool3d_baseline(np_data, in_shape, kernal, strides, padding,
out_shape, pool_type, count_include_pad=True):
# default layout is "NCDHW"
dtype = "float32"
n, ic, id, ih, iw = in_shape
kd, kw, kh = kernal
sd, sw, sh = strides
pf, pt, pl, pk, pb, pr = padding

pad_np = np.zeros(shape=(n, ic, id + pf + pk, ih + pt + pb, iw + pl + pr)).astype(dtype)
no_zero = (range(n), range(ic), (range(pf, id + pf)), (range(pt, ih + pt)), (range(pl, iw + pl)))
pad_np[np.ix_(*no_zero)] = np_data
ret_np = np.zeros(shape=out_shape).astype(dtype)

if pool_type == 'avg':
for k in range(out_shape[2]):
for i in range(out_shape[3]):
for j in range(out_shape[4]):
if count_include_pad:
ret_np[:, :, k, i, j] = \
np.mean(pad_np[:, :, k * sd:k * sd + kd,
i * sh:i * sh + kh, j * sw:j * sw + kw], axis=(2, 3, 4))
else:
pad_count = np.sum(pad_np[:, :, k * sd:k * sd + kd,
i * sh:i * sh + kh, j * sw:j * sw + kw] > 0, axis=(2, 3, 4))
ret_np[:, :, k, i, j] = np.sum(pad_np[:, :, k * sd:k * sd + kd,
i * sh:i * sh + kh, j * sw:j * sw + kw],
axis=(2, 3, 4)) / np.maximum(pad_count, 1)
elif pool_type == 'max':
for k in range(out_shape[2]):
for i in range(out_shape[3]):
for j in range(out_shape[4]):
ret_np[:, :, k, i, j] = np.max(
pad_np[:, :, k * sd:k * sd + kd,
i * sh:i * sh + kh, j * sw:j * sw + kw], axis=(2, 3, 4))
ret_np = np.maximum(ret_np, 0.0)
return ret_np

def _test_pool3d(opfunc):
n, c, d, h, w = tvm.var("n"), 10, 5, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
Expand All @@ -587,8 +549,8 @@ def _test_pool3d(opfunc):
y = opfunc(x, pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 0, 0, 0, 0, 0))
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = _test_pool3d_baseline(data, (1, 3, 32, 32, 32), (2, 2, 2), (2, 2, 2),
(0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16), pool_type, False)
ref_res = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2),
(0, 0, 0, 0, 0, 0), (1, 3, 16, 16, 16), pool_type, False)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .pool3d_python import pool3d_ncdhw_python
from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
82 changes: 82 additions & 0 deletions topi/python/topi/testing/pool3d_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, unused-variable
"""max_pool3d and avg_pool3d in python"""
import math
import numpy as np

def pool3d_ncdhw_python(np_data, kernel,
strides, padding,
out_shape, pool_type,
count_include_pad=True,
ceil_mode=False, dtype="float32"):
"""baseline for max_pool3d and avg_pool3d, default layout is "NCDHW"""
in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape
k_d, k_h, k_w = kernel
s_d, s_h, s_w = strides
pf, pt, pl, pk, pb, pr = padding

if ceil_mode:
assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
assert out_shape[3] == int(math.ceil(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
assert out_shape[4] == int(math.ceil(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)
else:
assert out_shape[2] == int(math.floor(float(in_shape[2] - k_d + pf + pk) / s_d) + 1)
assert out_shape[3] == int(math.floor(float(in_shape[3] - k_h + pt + pb) / s_h) + 1)
assert out_shape[4] == int(math.floor(float(in_shape[4] - k_w + pl + pr) / s_w) + 1)

pad_np = np.zeros(shape=(in_n, in_c,
in_d + pf + pk,
in_h + pt + pb,
in_w + pl + pr)).astype(dtype)
no_zero = (range(in_n),
range(in_c),
(range(pf, in_d + pf)),
(range(pt, in_h + pt)),
(range(pl, in_w + pl)))
pad_np[np.ix_(*no_zero)] = np_data
ret_np = np.zeros(shape=out_shape).astype(dtype)

if pool_type == 'avg':
for k in range(out_shape[2]):
for i in range(out_shape[3]):
for j in range(out_shape[4]):
if count_include_pad:
ret_np[:, :, k, i, j] = \
np.mean(pad_np[:, :, k * s_d: k * s_d + k_d,
i * s_h: i * s_h + k_h,
j * s_w: j * s_w + k_w], axis=(2, 3, 4))
else:
pad_count = np.sum(pad_np[:, :,
k * s_d: k * s_d + k_d,
i * s_h: i * s_h + k_h,
j * s_w: j * s_w + k_w] > 0, axis=(2, 3, 4))
ret_np[:, :, k, i, j] = np.sum(pad_np[:, :,
k * s_d: k * s_d + k_d,
i * s_h: i * s_h + k_h,
j * s_w: j * s_w + k_w],
axis=(2, 3, 4)) / np.maximum(pad_count, 1)
elif pool_type == 'max':
for k in range(out_shape[2]):
for i in range(out_shape[3]):
for j in range(out_shape[4]):
ret_np[:, :, k, i, j] = np.max(
pad_np[:, :, k * s_d: k * s_d + k_d,
i * s_h: i * s_h + k_h,
j * s_w: j * s_w + k_w], axis=(2, 3, 4))
ret_np = np.maximum(ret_np, 0.0)
return ret_np
71 changes: 19 additions & 52 deletions topi/tests/python/test_topi_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Test code for pooling"""
import math
import numpy as np
import tvm
import topi
import topi.testing
import math
from topi.util import get_const_tuple

from common import get_all_backend

def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
Expand Down Expand Up @@ -264,57 +263,25 @@ def test_adaptive_pool():
verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
verify_adaptive_pool((1, 5, 46, 97), (4, 96), "avg")

def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iz = iw = ih
kz = kw = kh
sz = sw = sh
pf, pt, pl, pk, pb, pr = padding
layout = "NCDHW"
A = tvm.placeholder((n, ic, iz, ih, iw), name='A')
B = topi.nn.pool3d(A, kernel=[kz, kh, kw], stride=[sz, sh, sw], padding=padding,
def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
ceil_mode, count_include_pad=True, layout='NCDHW'):
id = iw = ih
kd = kw = kh
sd = sw = sh
input_shape = (n, ic, id, ih, iw)
kernel = [kd, kh, kw]
stride = [sd, sh, sw]
A = tvm.placeholder(input_shape, name='A')
B = topi.nn.pool3d(A, kernel=kernel, stride=stride, padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout="NCDHW", count_include_pad=count_include_pad)
layout=layout, count_include_pad=count_include_pad)
B = topi.nn.relu(B)
dtype = A.dtype
output_shape = [int(i) for i in B.shape]

bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kz + pf + pk) / sz) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kh + pt + pb) / sh) + 1)
assert bshape[4] == int(math.ceil(float(ashape[4] - kw + pl + pr) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kz + pf + pk) / sz) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kh + pt + pb) / sh) + 1)
assert bshape[4] == int(math.floor(float(ashape[4] - kw + pl + pr) / sw) + 1)

a_np = np.random.uniform(low=0.001, size=(n, ic, iz, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, iz+pf+pk, ih+pt+pb, iw+pl+pr)).astype(dtype)
no_zero = (range(n), range(ic), (range(pf, iz+pf)), (range(pt, ih+pt)), (range(pl, iw+pl)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oz, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oz, oh, ow)).astype(dtype)

if pool_type == 'avg':
for k in range(oz):
for i in range(oh):
for j in range(ow):
if count_include_pad:
b_np[:,:,k,i,j] = np.mean( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
else:
pad_count = np.sum( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3,4))
b_np[:,:,k,i,j] = np.sum(pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], \
axis=(2,3, 4)) / np.maximum(pad_count, 1)

elif pool_type =='max':
for k in range(oz):
for i in range(oh):
for j in range(ow):
b_np[:,:,k,i,j] = np.max( \
pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4))
b_np = np.maximum(b_np, 0.0)
input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
ref_np = topi.testing.pool3d_ncdhw_python(input_np, kernel, stride, padding,
output_shape, pool_type, count_include_pad, ceil_mode)

def check_device(device):
ctx = tvm.context(device, 0)
Expand All @@ -325,11 +292,11 @@ def check_device(device):
with tvm.target.create(device):
s = topi.generic.schedule_pool(B, layout)

a = tvm.nd.array(a_np, ctx)
a = tvm.nd.array(input_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5)

for device in get_all_backend():
check_device(device)
Expand All @@ -353,7 +320,7 @@ def test_pool3d():

if __name__ == "__main__":
test_pool()
test_pool3d()
test_pool_grad()
test_global_pool()
test_adaptive_pool()
test_pool3d()

0 comments on commit 2937662

Please sign in to comment.