Skip to content

Commit

Permalink
Bugfix apache#1692. Constant folding and result comparision allowance.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Sep 12, 2018
1 parent 0565fcc commit f09a287
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
5 changes: 3 additions & 2 deletions topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import topi
from ..util import simplify


def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
Expand Down Expand Up @@ -31,9 +32,9 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
"""

if layout == "NCHW":
out_shape = (data.shape[2] * scale, data.shape[3] * scale)
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (data.shape[1] * scale, data.shape[2] * scale)
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
else:
raise ValueError("not support this layout {} yet".format(layout))

Expand Down
29 changes: 21 additions & 8 deletions topi/tests/python/test_topi_upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import topi.testing
import math

def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'):
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"):


if layout == 'NCHW':
Expand All @@ -22,9 +22,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
raise NotImplementedError(
'Layout not supported {} '.format(layout))

B = topi.nn.upsampling(A, scale, layout=layout)
B = topi.nn.upsampling(A, scale, layout=layout, method=method)

b_np = topi.testing.upsampling_python(a_np, scale, layout)
if method == "BILINEAR":
out_size = (in_height*scale, in_width*scale)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout)
else:
b_np = topi.testing.upsampling_python(a_np, scale, layout)

def check_device(device):
ctx = tvm.context(device, 0)
Expand All @@ -39,18 +43,27 @@ def check_device(device):
f = tvm.build(s, [A, B], device)
f(a, b)

np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
check_device(device)

def test_upsampling():
# NCHW
# NEAREST_NEIGHBOR - NCHW
verify_upsampling(8, 16, 32, 32, 2)
verify_upsampling(12, 32, 64, 64, 3)
# NHWC
verify_upsampling(8, 16, 32, 32, 2, "NHWC")
verify_upsampling(12, 32, 64, 64, 3, "NHWC")

# NEAREST_NEIGHBOR - NHWC
verify_upsampling(8, 16, 32, 32, 2, layout="NHWC")
verify_upsampling(12, 32, 64, 64, 3, layout="NHWC")

# BILINEAR - NCHW
verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR")

# BILINEAR - NHWC
verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR")

if __name__ == "__main__":
test_upsampling()

0 comments on commit f09a287

Please sign in to comment.