From 8490f0104aace1ebedb956e107c90511f1328ec8 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jan 2020 20:03:14 -0800 Subject: [PATCH] Revert "[Relay][TOPI]Fix meaning of conv2d_transpose output_padding parameter (#4318)" (#4708) This reverts commit dcf7fbf1f962569e78c624755b2d612fffa81ada. --- python/tvm/autotvm/tophub.py | 6 ++--- python/tvm/relay/op/nn/_nn.py | 12 ++++++--- tests/python/relay/test_op_level2.py | 25 ++++++++++--------- topi/python/topi/arm_cpu/conv2d_transpose.py | 21 ++++++---------- topi/python/topi/cuda/conv1d_transpose_ncw.py | 7 +++--- .../python/topi/cuda/conv2d_transpose_nchw.py | 12 ++++----- topi/python/topi/nn/conv1d_transpose.py | 6 ++--- topi/python/topi/nn/conv2d_transpose.py | 24 ++++++++---------- .../testing/conv1d_transpose_ncw_python.py | 7 +++--- .../topi/testing/conv2d_transpose_python.py | 17 ++++++------- topi/python/topi/x86/conv2d_transpose.py | 4 +-- .../python/test_topi_conv1d_transpose_ncw.py | 2 +- .../python/test_topi_conv2d_transpose_nchw.py | 18 ++++++------- vta/python/vta/top/vta_conv2d_transpose.py | 16 +++++------- vta/scripts/tune_conv2d_transpose.py | 19 ++++++-------- .../test_benchmark_topi_conv2d_transpose.py | 16 ++++++------ 16 files changed, 92 insertions(+), 120 deletions(-) diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 2f07f72c7cd8a..d953eaaeea9df 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -46,16 +46,16 @@ # the version of each package PACKAGE_VERSION = { - 'arm_cpu': "v0.05", + 'arm_cpu': "v0.04", 'llvm': "v0.03", - 'cuda': "v0.07", + 'cuda': "v0.06", 'rocm': "v0.03", 'opencl': "v0.03", 'mali': "v0.05", 'intel_graphics': "v0.01", - 'vta': "v0.07", + 'vta': "v0.06", } logger = logging.getLogger('autotvm') diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e405fee916dc1..275e09c84535c 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -339,7 +339,6 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) dilation = get_const_tuple(attrs.dilation) - output_padding = get_const_tuple(attrs.output_padding) groups = attrs.groups layout = attrs.data_layout out_dtype = attrs.out_dtype @@ -349,7 +348,10 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): assert dilation == (1, 1), "not support dilate now" assert groups == 1, "only support groups == 1 for now" out = topi.nn.conv2d_transpose_nchw( - inputs[0], inputs[1], strides, padding, out_dtype, output_padding) + inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) return [out] @@ -442,8 +444,10 @@ def compute_conv1d_transpose(attrs, inputs, out_dtype, target): assert dilation == (1,), "conv1d_transpose dilation is not supported" assert groups == 1, "conv1d_transpose groups == 1 only supported" out = topi.nn.conv1d_transpose_ncw( - inputs[0], inputs[1], strides, padding, out_dtype, - get_const_tuple(attrs.output_padding)) + inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0], [0, 0, output_padding[0]]) return [out] diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 68f398396c051..4b914ee11fee8 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -570,8 +570,11 @@ def test_conv2d_transpose_nchw_run(): dtype = "float32" data = np.random.uniform(size=dshape).astype(dtype) kernel = np.random.uniform(size=kshape).astype(dtype) - ref_res = topi.testing.conv2d_transpose_nchw_python( - data, kernel, 2, 1, (2, 2)) + c_np = topi.testing.conv2d_transpose_nchw_python( + data, kernel, 2, 1) + d_np = np.zeros(shape=oshape) + d_np[:,:,0:c_np.shape[2],0:c_np.shape[3]] = c_np + ref_res = d_np for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) @@ -596,14 +599,9 @@ def test_conv2d_transpose_nhwc_run(): data = np.random.uniform(size=dshape_nhwc).astype(dtype) kernel = np.random.uniform(size=kshape_hwoi).astype(dtype) # use true kshape layout here - HWOI - - ref_res = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', - 2, 1, output_padding=(2, 2)) - - for target, ctx in ctx_list(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1) + d_np = np.zeros(shape=oshape_nhwc) + d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np def test_conv1d_transpose_ncw_run(): @@ -619,8 +617,11 @@ def test_conv1d_transpose_ncw_run(): dtype = "float32" data = np.random.uniform(size=dshape).astype(dtype) kernel = np.random.uniform(size=kshape).astype(dtype) - ref_res = topi.testing.conv1d_transpose_ncw_python( - data, kernel, 2, 1, output_padding=(2,)) + c_np = topi.testing.conv1d_transpose_ncw_python( + data, kernel, 2, 1) + d_np = np.zeros(shape=oshape) + d_np[:,:,0:c_np.shape[2]] = c_np + ref_res = d_np for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) diff --git a/topi/python/topi/arm_cpu/conv2d_transpose.py b/topi/python/topi/arm_cpu/conv2d_transpose.py index 483c7983a3ea6..65f1024c88a30 100644 --- a/topi/python/topi/arm_cpu/conv2d_transpose.py +++ b/topi/python/topi/arm_cpu/conv2d_transpose.py @@ -27,8 +27,7 @@ from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw @autotvm.task.register_topi_compute(conv2d_transpose_nchw, "arm_cpu", "direct") -def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype, - output_padding=(0, 0)): +def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype): """Transposed 2D convolution nchw forward operator. Parameters @@ -48,33 +47,27 @@ def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype, out_dtype: str The output data type. This is used for mixed precision. - output_padding : tuple of int - Used to get the right output shape in gradients - Returns ------- Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2, - output_padding) + return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2) -def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile, - output_padding): +def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile): assert layout == "NCHW", "Only support NCHW" out_dtype = out_dtype or data.dtype N, CI, IH, IW = get_const_tuple(data.shape) _, CO, KH, KW = get_const_tuple(kernel.shape) - opad_h, opad_w = output_padding pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) - bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom + opad_h - bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right + opad_w + bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom + bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH + opad_h - OW = (IW - 1) * WSTR - pad_left - pad_right + KW + opad_w + OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH + OW = (IW - 1) * WSTR - pad_left - pad_right + KW dilated_input = dilate(data, [1, 1, HSTR, WSTR]) data_pad = pad(dilated_input, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right]) diff --git a/topi/python/topi/cuda/conv1d_transpose_ncw.py b/topi/python/topi/cuda/conv1d_transpose_ncw.py index 9f15264f1e7e0..be7824e71e814 100644 --- a/topi/python/topi/cuda/conv1d_transpose_ncw.py +++ b/topi/python/topi/cuda/conv1d_transpose_ncw.py @@ -23,7 +23,7 @@ from ..util import get_const_tuple, traverse_inline @autotvm.task.register_topi_compute(nn.conv1d_transpose_ncw, ['cuda', 'gpu'], "direct") -def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype, output_padding=(0,)): +def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype): """Transposed 1D convolution ncw forward operator. Parameters @@ -53,11 +53,10 @@ def conv1d_transpose_ncw_cuda(cfg, data, kernel, stride, padding, out_dtype, out cfg.stride = stride batch, inp_channels, inp_width = get_const_tuple(data.shape) _, out_channels, kernel_size = get_const_tuple(kernel.shape) - opad = output_padding[0] pad_left, pad_right = nn.get_pad_tuple1d(padding, kernel_size) - out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + opad + out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right pad_left = kernel_size - 1 - pad_left - pad_right = kernel_size - 1 - pad_right + opad + pad_right = kernel_size - 1 - pad_right dilated_width = stride * (inp_width - 1) + 1 data = tvm.compute( (batch, inp_channels, pad_left + dilated_width + pad_right), diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index a630e9071c0dd..274dfb03e7946 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -25,8 +25,7 @@ @autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct") -def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype, - output_padding=(0, 0)): +def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype): """Transposed 2D convolution nchw forward operator. Parameters @@ -52,7 +51,6 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype, batch, in_c, in_h, in_w = get_const_tuple(Input.shape) _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape) stride_h, stride_w = strides - opad_h, opad_w = output_padding # attach stride info to config, this is used in schedule space definition cfg.stride = strides @@ -60,9 +58,9 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype, # padding stage fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w)) bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom + opad_h + bpad_bottom = filter_h - 1 - fpad_bottom bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right + opad_w + bpad_right = filter_w - 1 - fpad_right # padding stage FirstPad = nn.pad(Input, @@ -97,8 +95,8 @@ def _dilate(*indices): return data(*index_tuple) # convolution stage - out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w dc = tvm.reduce_axis((0, in_c), name='dc') dh = tvm.reduce_axis((0, filter_h), name='dh') dw = tvm.reduce_axis((0, filter_w), name='dw') diff --git a/topi/python/topi/nn/conv1d_transpose.py b/topi/python/topi/nn/conv1d_transpose.py index 23c0f577b5239..39918e90c3173 100644 --- a/topi/python/topi/nn/conv1d_transpose.py +++ b/topi/python/topi/nn/conv1d_transpose.py @@ -25,8 +25,7 @@ @tvm.target.generic_func -def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, - output_padding=(0,)): +def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype): """Transposed 1D convolution ncw forward operator. Parameters @@ -57,12 +56,11 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype, stride = stride[0] batch, channels_in, data_width = data.shape _, channels_out, kernel_width = kernel.shape - opad = output_padding[0] channels_out = simplify(channels_out) data = dilate(data, [1, 1, stride], name='data_dilate') pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,)) pad_left = kernel_width - 1 - pad_left - pad_right = kernel_width - 1 - pad_right + opad + pad_right = kernel_width - 1 - pad_right data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad') # transpose kernel, switch kernel layout to IOW diff --git a/topi/python/topi/nn/conv2d_transpose.py b/topi/python/topi/nn/conv2d_transpose.py index 7baf18ffa9439..a240b687c86d3 100644 --- a/topi/python/topi/nn/conv2d_transpose.py +++ b/topi/python/topi/nn/conv2d_transpose.py @@ -26,7 +26,7 @@ @tvm.target.generic_func -def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding=(0, 0)): +def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype): """Transposed 2D convolution nchw forward operator. Parameters @@ -46,33 +46,28 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_pad out_dtype : str The output data type. This is used for mixed precision. - output_padding : tuple of ints - Used to get the right output shape for gradients - Returns ------- Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype, - output_padding=output_padding) + return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype) -def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding): +def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype): """Preprocess data and kernel to make the compute pattern of conv2d_transpose the same as conv2d""" batch, in_c, in_h, in_w = data.shape _, out_c, filter_h, filter_w = kernel.shape stride_h, stride_w = strides - opad_h, opad_w = output_padding # dilate data data_dilate = dilate(data, [1, 1, stride_h, stride_w], name='data_dilate') # pad data fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w)) bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom + opad_h + bpad_bottom = filter_h - 1 - fpad_bottom bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right + opad_w + bpad_right = filter_w - 1 - fpad_right data_pad = pad(data_dilate, \ [0, 0, bpad_top, bpad_left], \ [0, 0, bpad_bottom, bpad_right], \ @@ -84,17 +79,18 @@ def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, return data_pad, kernel_transform -def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding): +def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype): """Implementation of conv2d transpose""" data_pad, kernel_transform = \ - conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding) + conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype) batch, in_c, in_h, in_w = data_pad.shape out_c, _, filter_h, filter_w = kernel_transform.shape + stride_h, stride_w = strides # convolution stage out_c = simplify(out_c) - out_h = simplify(in_h - filter_h + 1 + output_padding[0]) - out_w = simplify(in_w - filter_w + 1 + output_padding[1]) + out_h = simplify(in_h - filter_h + 1) + out_w = simplify(in_w - filter_w + 1) dc = tvm.reduce_axis((0, in_c), name='dc') dh = tvm.reduce_axis((0, filter_h), name='dh') dw = tvm.reduce_axis((0, filter_w), name='dw') diff --git a/topi/python/topi/testing/conv1d_transpose_ncw_python.py b/topi/python/topi/testing/conv1d_transpose_ncw_python.py index 4461291e6f2a2..cb78bbf8cb3f9 100644 --- a/topi/python/topi/testing/conv1d_transpose_ncw_python.py +++ b/topi/python/topi/testing/conv1d_transpose_ncw_python.py @@ -21,7 +21,7 @@ import topi from topi.nn.util import get_pad_tuple1d -def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding): +def conv1d_transpose_ncw_python(a_np, w_np, stride, padding): """Transposed 1D convolution operator in NCW layout. Parameters @@ -47,7 +47,6 @@ def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding): """ batch, in_c, in_w = a_np.shape _, out_c, filter_w = w_np.shape - opad = output_padding[0] if isinstance(stride, int): stride_w = stride else: @@ -57,11 +56,11 @@ def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding): dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_w]) # padding stage bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right + opad + bpad_right = filter_w - 1 - fpad_right padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_left+bpad_right)) padded_a_np[:, :, bpad_left:dilated_a_np.shape[2]+bpad_left] = dilated_a_np # convolution stage - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w b_np = np.zeros((batch, out_c, out_w)) for n in range(batch): for f in range(out_c): diff --git a/topi/python/topi/testing/conv2d_transpose_python.py b/topi/python/topi/testing/conv2d_transpose_python.py index 7c56cf86c6274..50c43eb70e3e0 100644 --- a/topi/python/topi/testing/conv2d_transpose_python.py +++ b/topi/python/topi/testing/conv2d_transpose_python.py @@ -22,7 +22,7 @@ from topi.nn.util import get_pad_tuple -def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding=(0, 0)): +def conv2d_transpose_nchw_python(a_np, w_np, stride, padding): """Transposed convolution operator in NCHW layout. Parameters @@ -50,22 +50,21 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding=(0, stride_h = stride_w = stride else: stride_h, stride_w = stride - opad_h, opad_w = output_padding # dilate stage dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w]) # padding stage fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w)) bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom + opad_h + bpad_bottom = filter_h - 1 - fpad_bottom bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right + opad_w + bpad_right = filter_w - 1 - fpad_right padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_top+bpad_bottom, \ dilated_a_np.shape[3]+bpad_left+bpad_right)) padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top, \ bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np # convolution stage - out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w b_np = np.zeros((batch, out_c, out_h, out_w)) for n in range(batch): for f in range(out_c): @@ -76,8 +75,7 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding=(0, return b_np -def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding, - output_padding=(0, 0)): +def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding): """Transposed convolution operator in NHWC layout. Parameters @@ -119,7 +117,6 @@ def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding, else: raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW') - res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding, - output_padding=output_padding) + res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding) res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1)) return res_nhwc diff --git a/topi/python/topi/x86/conv2d_transpose.py b/topi/python/topi/x86/conv2d_transpose.py index 3cbc57be4c1fa..27fc0afce999f 100644 --- a/topi/python/topi/x86/conv2d_transpose.py +++ b/topi/python/topi/x86/conv2d_transpose.py @@ -28,9 +28,9 @@ @autotvm.register_topi_compute(conv2d_transpose_nchw, 'cpu', ['direct']) -def _conv2d_transpose_nchw(cfg, data, kernel, strides, padding, out_dtype, output_padding=(0, 0)): +def _conv2d_transpose_nchw(cfg, data, kernel, strides, padding, out_dtype): data_pad, kernel_transform = \ - conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype, output_padding) + conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype) # reuse conv2d implementation _create_tuning_space_conv2d(cfg, data_pad, kernel_transform, strides=(1, 1), \ padding=(0, 0), dilation=(1, 1), layout="NCHW") diff --git a/topi/tests/python/test_topi_conv1d_transpose_ncw.py b/topi/tests/python/test_topi_conv1d_transpose_ncw.py index a05acd33e4984..9d6e9db254b50 100644 --- a/topi/tests/python/test_topi_conv1d_transpose_ncw.py +++ b/topi/tests/python/test_topi_conv1d_transpose_ncw.py @@ -37,7 +37,7 @@ def verify_conv1d_transpose_ncw(batch, in_channel, in_size, num_filter, kernel, def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding, (0,)) + b_np = topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding) c_np = np.maximum(b_np, 0) return a_np, w_np, b_np, c_np diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py index fe07df8a2a612..0960760a89de3 100644 --- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py +++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py @@ -24,7 +24,7 @@ from common import get_all_backend -def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding): +def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): in_height = in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') @@ -38,7 +38,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding) + b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding) c_np = np.maximum(b_np, 0) return a_np, w_np, b_np, c_np @@ -51,7 +51,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype, output_padding) + B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype) C = topi.nn.relu(B) s1 = topi.generic.schedule_conv2d_transpose_nchw([B]) s2 = topi.generic.schedule_conv2d_transpose_nchw([C]) @@ -72,13 +72,11 @@ def check_device(device): def test_conv2d_transpose_nchw(): - verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0, (0, 0)) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1, (0, 0)) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1, (1, 0)) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0, (0, 0)) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0, (1, 1)) - verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0, (0, 0)) - verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1, (0, 0)) + verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0) + verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1) + verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0) + verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0) + verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1) if __name__ == "__main__": diff --git a/vta/python/vta/top/vta_conv2d_transpose.py b/vta/python/vta/top/vta_conv2d_transpose.py index ff10ff0153481..a2750dc9081d3 100644 --- a/vta/python/vta/top/vta_conv2d_transpose.py +++ b/vta/python/vta/top/vta_conv2d_transpose.py @@ -27,28 +27,24 @@ from ..environment import get_env @autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct') -def _declaration_conv2d_transpose(cfg, +def _declatation_conv2d_transpose(cfg, data, kernel, strides, padding, - out_dtype, - output_padding=(0, 0)): + out_dtype): ishape = get_const_tuple(data.shape) kshape = get_const_tuple(kernel.shape) b, c_i, i_h, i_w, t_b, t_ci = ishape c_o, _, k_h, k_w, t_co, t_ci = kshape stride_h, stride_w = strides - opad_h, opad_w = output_padding - # FIXME(tmoreau89): currently IR pass breaks when output padding != (0,0) - assert opad_h == 0 and opad_w == 0, "VTA does not support output padding for now" # derive padding parameters fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (k_h, k_w)) bpad_top = k_h - 1 - fpad_top - bpad_bottom = k_h - 1 - fpad_bottom + opad_h + bpad_bottom = k_h - 1 - fpad_bottom bpad_left = k_w - 1 - fpad_left - bpad_right = k_w - 1 - fpad_right + opad_w + bpad_right = k_w - 1 - fpad_right # padding stage dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1]) @@ -57,8 +53,8 @@ def _declaration_conv2d_transpose(cfg, [0, 0, bpad_bottom, bpad_right, 0, 0]) # convolution transpose stage - out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h - out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w + out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w oshape = (b, c_o, out_h, out_w, t_b, t_co) d_c = tvm.reduce_axis((0, c_i), name='d_c') d_h = tvm.reduce_axis((0, k_h), name='d_h') diff --git a/vta/scripts/tune_conv2d_transpose.py b/vta/scripts/tune_conv2d_transpose.py index fa9900a121c4d..3e51d410638b2 100644 --- a/vta/scripts/tune_conv2d_transpose.py +++ b/vta/scripts/tune_conv2d_transpose.py @@ -33,15 +33,13 @@ Workload = namedtuple("Conv2DTransposeWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride', - 'o_hpad', 'o_wpad']) + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) -# DCGAN workloads dcgan_wkls = [ # dcgan - ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)), - ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)), - ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)), + ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), ] @tvm.tag_scope(tag=topi.tag.ELEMWISE) @@ -53,7 +51,7 @@ def my_clip(x, a_min, a_max): x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") return x -def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding): +def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding): data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN) kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN) @@ -66,9 +64,7 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding): Filter=kernel, strides=strides, padding=padding, - out_dtype=env.acc_dtype, - output_padding=opadding - ) + out_dtype=env.acc_dtype) res = topi.right_shift(res, env.WGT_WIDTH) res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) res = topi.cast(res, env.out_dtype) @@ -113,12 +109,11 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding): KW = wl.wkernel strides = (wl.hstride, wl.wstride) padding = (wl.hpad, wl.wpad) - opadding = (wl.o_hpad, wl.o_wpad) # Create task task = autotvm.task.create( conv2d_transpose, - args=(N, CI, H, W, CO, KH, KW, strides, padding, opadding), + args=(N, CI, H, W, CO, KH, KW, strides, padding), target=tvm.target.vta(), target_host=env.target_host, template_key='direct') diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py index 235076c795287..e2601d1a424ff 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py @@ -37,8 +37,7 @@ Workload = namedtuple("Conv2DTransposeWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride', - 'o_hpad', 'o_wpad']) + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) # Get batch info from env env = vta.get_env() @@ -46,9 +45,9 @@ # DCGAN workloads dcgan_wklds = [ # dcgan - ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)), - ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)), - ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)), + ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), ] # FIXME: we need a custom clip operator to circumvent a pattern detection limitation @@ -103,8 +102,7 @@ def run_conv2d_transpose(env, remote, wl, target, # Define base computation schedule with target: res = topi.nn.conv2d_transpose_nchw( - data, kernel, (wl.hstride, wl.wstride), - (wl.hpad, wl.wpad), env.acc_dtype, (wl.o_hpad, wl.o_wpad)) + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), env.acc_dtype) res = topi.right_shift(res, env.WGT_WIDTH) res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) res = topi.cast(res, env.out_dtype) @@ -114,8 +112,8 @@ def run_conv2d_transpose(env, remote, wl, target, print(vta.lower(s, [data, kernel, res], simple_mode=True)) # Derive number of ops - fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + wl.o_hpad - fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + wl.o_wpad + fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")