From 1cb3485188f18c5f866e2ca460623f2eb01e1279 Mon Sep 17 00:00:00 2001 From: Alex Gladkov Date: Tue, 21 Jan 2020 21:48:13 +0000 Subject: [PATCH] Improve CUDA conv2d_transpose_nchw - combine pad and dilate; - fix for the issue https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164 - fix for the issue https://github.com/apache/incubator-tvm/pull/4472 --- .../python/topi/cuda/conv2d_transpose_nchw.py | 136 +++++++----------- .../python/test_topi_conv2d_transpose_nchw.py | 29 ++-- 2 files changed, 74 insertions(+), 91 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index 274dfb03e794..26bc26169674 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -21,11 +21,11 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. import nn, generic -from ..util import equal_const_int, get_const_tuple, traverse_inline +from ..util import get_const_tuple, traverse_inline @autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct") -def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype): +def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype): """Transposed 2D convolution nchw forward operator. Parameters @@ -48,67 +48,58 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype): Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - batch, in_c, in_h, in_w = get_const_tuple(Input.shape) - _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape) - stride_h, stride_w = strides - - # attach stride info to config, this is used in schedule space definition - cfg.stride = strides - - # padding stage - fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w)) - bpad_top = filter_h - 1 - fpad_top - bpad_bottom = filter_h - 1 - fpad_bottom - bpad_left = filter_w - 1 - fpad_left - bpad_right = filter_w - 1 - fpad_right - - # padding stage - FirstPad = nn.pad(Input, - [0, 0, (bpad_top + stride_h - 1) // stride_h, - (bpad_left + stride_w - 1) // stride_w], - [0, 0, (bpad_bottom + stride_h - 1) // stride_h, - (bpad_right + stride_w - 1) // stride_w], name='FirstPad') - - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod - # remove extra padding introduced by dilatation - border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h) - border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w) - - # dilation stage - data = FirstPad - strides = [1, 1, stride_h, stride_w] - n = len(data.shape) - - def _dilate(*indices): - not_zero = [] - index_tuple = [] - for i in range(n): - if not equal_const_int(strides[i], 1): - index_tuple.append(idxdiv(indices[i], strides[i])) - not_zero.append(idxmod(indices[i], strides[i]).equal(0)) - else: - index_tuple.append(indices[i]) - if not_zero: - not_zero = tvm.all(*not_zero) - return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype)) - return data(*index_tuple) - - # convolution stage - out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w - dc = tvm.reduce_axis((0, in_c), name='dc') - dh = tvm.reduce_axis((0, filter_h), name='dh') - dw = tvm.reduce_axis((0, filter_w), name='dw') - - Output = tvm.compute( - (batch, out_c, out_h, out_w), + batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape) + _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape) + stride_height, stride_width = stride + cfg.stride = stride + pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple( + padding, (kernel_height, kernel_width)) + + out_width = (inp_width - 1) * stride_width + \ + kernel_width - pad_left - pad_right + pad_left = kernel_width - 1 - pad_left + pad_right = kernel_width - 1 - pad_right + dilated_width = stride_width * (inp_width - 1) + 1 + + out_height = (inp_height - 1) * stride_height + \ + kernel_height - pad_top - pad_bottom + pad_top = kernel_height - 1 - pad_top + pad_bottom = kernel_height - 1 - pad_bottom + dilated_height = stride_height * (inp_height - 1) + 1 + + # compute pad + data = tvm.compute( + (batch, inp_channels, + pad_top + dilated_height + pad_bottom, + pad_left + dilated_width + pad_right), + lambda n, c, y, x: tvm.if_then_else( + tvm.all(x >= pad_left, + x < pad_left + dilated_width, + tvm.indexmod(x - pad_left, stride_width).equal(0), + y >= pad_top, + y < pad_top + dilated_height, + tvm.indexmod(y - pad_top, stride_height).equal(0)), + data[n, c, + tvm.indexdiv(y - pad_top, stride_height), + tvm.indexdiv(x - pad_left, stride_width)], + tvm.const(0., "float32")), + name='data_pad') + + # compute transposed conv + dc = tvm.reduce_axis((0, inp_channels), name='dc') + dh = tvm.reduce_axis((0, kernel_height), name='dh') + dw = tvm.reduce_axis((0, kernel_width), name='dw') + data_out = tvm.compute( + (batch, out_channels, out_height, out_width), lambda b, c, h, w: tvm.sum( - _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) * - Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype), + data[b, dc, h + dh, w + dw].astype(out_dtype) * + kernel[dc, + c, + kernel_height - 1 - dh, + kernel_width - 1 - dw].astype(out_dtype), axis=[dc, dh, dw]), tag="conv2d_transpose_nchw") - return Output + return data_out @autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw, ['cuda', 'gpu'], 'direct') @@ -140,7 +131,8 @@ def _fallback_schedule(N, F, Y, X): else: cfg["tile_n"] = SplitEntity([1, 1, 1, 1]) # split F (output channel dimension) - cfg["tile_f"] = SplitEntity([-1, 1, 64, 1]) + if F > 1: + cfg["tile_f"] = SplitEntity([-1, 1, 64, 1]) # split Y (height dimension) y_split_factor = 1 for candidate in range(5, 17): @@ -185,26 +177,8 @@ def _callback(op): cfg.define_knob("unroll_explicit", [0, 1]) if cfg.is_fallback: - ko = int(kernel.shape[1]) - kh = int(kernel.shape[2]) - kw = int(kernel.shape[3]) - stride_h, stride_w = cfg.stride - # Workaround to make CUDA compilation work. Issue #4470 - # TODO make _fallback_schedule work for all kernel/strides combinations - # after issue #4470 is resolved - do_fallback = True - if ko == 1: - do_fallback = False - elif (kh, kw) == (1, 1): - do_fallback = True - elif (stride_h, stride_w) == (2, 2): - do_fallback = False - elif (kh, kw) == (stride_h, stride_w): - do_fallback = False - - if do_fallback: - N, F, Y, X = get_const_tuple(conv.shape) - _fallback_schedule(N, F, Y, X) + N, F, Y, X = get_const_tuple(conv.shape) + _fallback_schedule(N, F, Y, X) ##### space definition end ##### diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py index 0960760a89de..fb836d43ccce 100644 --- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py +++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py @@ -25,10 +25,13 @@ from common import get_all_backend def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size + in_height, in_width = in_size + kernel_height, kernel_width = kernel + stride_height, stride_width = stride + pad_top, pad_left, pad_bottom, pad_right = padding A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W') + W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W') a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) @@ -51,7 +54,10 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype) + B = topi.nn.conv2d_transpose_nchw(A, W, + [stride_height, stride_width], + [pad_top, pad_left, pad_bottom, pad_right], + A.dtype) C = topi.nn.relu(B) s1 = topi.generic.schedule_conv2d_transpose_nchw([B]) s2 = topi.generic.schedule_conv2d_transpose_nchw([C]) @@ -66,18 +72,21 @@ def check_device(device): func2(a, w, c) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in get_all_backend(): check_device(device) def test_conv2d_transpose_nchw(): - verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1) - verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0) - verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0) - verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1) - + verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1)) + verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0)) + verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1)) + verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0)) + verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0)) if __name__ == "__main__": test_conv2d_transpose_nchw()