From 8eb15b7ba3303b764ba4b8fdec2e72cb7ea97b08 Mon Sep 17 00:00:00 2001 From: Alicja Kwasniewska Date: Wed, 20 Oct 2021 21:19:45 -0700 Subject: [PATCH] Added group transposed convolution Change includes topi implementation, tests, generic and x86 strategy for group transposed convolution. Signed-off-by: Alicja Kwasniewska --- python/tvm/relay/op/strategy/generic.py | 25 ++- python/tvm/relay/op/strategy/x86.py | 18 +- python/tvm/topi/generic/nn.py | 17 ++ python/tvm/topi/nn/conv2d_transpose.py | 88 +++++++++- .../topi/testing/conv2d_transpose_python.py | 40 ++++- src/relay/op/nn/convolution.h | 21 ++- .../test_topi_group_conv2d_transpose.py | 156 ++++++++++++++++++ 7 files changed, 344 insertions(+), 21 deletions(-) create mode 100644 tests/python/topi/python/test_topi_group_conv2d_transpose.py diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 777f17ba6084..71b9a078b30e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target): # conv2d_transpose -def wrap_compute_conv2d_transpose(topi_compute): +def wrap_compute_conv2d_transpose(topi_compute, has_groups=False): """wrap conv2d_transpose topi compute""" def compute_conv2d_transpose(attrs, inputs, out_dtype): @@ -456,7 +456,10 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype): out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype output_padding = get_const_tuple(attrs.output_padding) - out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding) + args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding] + if has_groups: + args.append(attrs.groups) + out = topi_compute(*args) return [out] return compute_conv2d_transpose @@ -471,13 +474,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target): groups = attrs.groups assert layout == "NCHW", "only support nchw for now" assert dilation == (1, 1), "not support dilate now" - assert groups == 1, "only support groups == 1 for now" strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw), - wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw), - name="conv2d_transpose_nchw.generic", - ) + if groups == 1: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw), + wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw), + name="conv2d_transpose_nchw.generic", + ) + else: # group_transpose_conv2d + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True), + wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw), + name="group_conv2d_transpose_nchw.generic", + ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 1c8d1b478cb1..a421b120fab4 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target): groups = attrs.groups assert layout == "NCHW", "only support nchw for now" assert dilation == (1, 1), "not support dilate now" - assert groups == 1, "only support groups == 1 for now" strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw), - wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw), 
- name="conv2d_transpose_nchw.x86", - ) + if groups == 1: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw), + wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw), + name="conv2d_transpose_nchw.x86", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True), + wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw), + name="group_conv2d_transpose_nchw.x86", + ) return strategy diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index 1b3214154687..fceb6adea63c 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -345,6 +345,23 @@ def schedule_conv2d_transpose_nchw(outs): return _default_schedule(outs, False) +def schedule_group_conv2d_transpose_nchw(outs): + """Schedule for group_conv2d_transpose_nchw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of group_conv2d_transpose_nchw + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_conv1d_transpose_ncw(outs): """Schedule for conv1d_transpose_ncw diff --git a/python/tvm/topi/nn/conv2d_transpose.py b/python/tvm/topi/nn/conv2d_transpose.py index 22188bcd45a4..47a9a6d8e894 100644 --- a/python/tvm/topi/nn/conv2d_transpose.py +++ b/python/tvm/topi/nn/conv2d_transpose.py @@ -22,7 +22,7 @@ from .dilate import dilate from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify +from ..utils import get_const_tuple, simplify def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding): @@ -173,3 +173,89 @@ def conv2d_transpose_legalize(attrs, inputs, types): return out return None + + +def group_conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding, groups): + """Group convolution operator in NCHW layout. + + Parameters + ---------- + Input : tvm.te.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + Filter : tvm.te.Tensor + 4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + out_dtype : str + The output data type. This is used for mixed precision. + + output_padding : tuple of ints + Used to get the right output shape for gradients + + groups : int + number of groups + + out_dtype : str + The output type. This is used for mixed precision. 
+ + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + + if groups == 1: + return conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding) + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + batch, in_channel, _, _ = get_const_tuple(Input.shape) + in_channel_w, _, _, _ = get_const_tuple(Filter.shape) + + assert in_channel % groups == 0, "input channels must divide group size" + assert in_channel_w % groups == 0, "weight channels must divide group size" + + data_pad, kernel_transform = conv2d_transpose_nchw_preprocess( + Input, Filter, stride, padding, out_dtype, output_padding + ) + batch, in_c, in_h, in_w = data_pad.shape + out_c, _, filter_h, filter_w = kernel_transform.shape + + out_c = simplify(out_c) + out_height = simplify(in_h - filter_h + 1) + out_width = simplify(in_w - filter_w + 1) + + # compute graph + rc = te.reduce_axis((0, in_c // groups), name="rc") + ry = te.reduce_axis((0, filter_h), name="ry") + rx = te.reduce_axis((0, filter_w), name="rx") + return te.compute( + (batch, out_c * groups, out_height, out_width), + lambda nn, ff, yy, xx: te.sum( + data_pad[ + nn, + ff // ((out_c * groups) // groups) * (in_c // groups) + rc, + yy + ry, + xx + rx, + ].astype(out_dtype) + * kernel_transform[ + ff % out_c, ff // ((out_c * groups) // groups) * (in_c // groups) + rc, ry, rx + ].astype(out_dtype), + axis=[rc, ry, rx], + ), + tag="group_conv2d_transpose_nchw", + ) diff --git a/python/tvm/topi/testing/conv2d_transpose_python.py b/python/tvm/topi/testing/conv2d_transpose_python.py index c7c0d9f2529a..c948a0715fd3 100644 --- a/python/tvm/topi/testing/conv2d_transpose_python.py +++ b/python/tvm/topi/testing/conv2d_transpose_python.py @@ -22,7 +22,7 @@ from tvm.topi.nn.utils import get_pad_tuple -def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): +def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): """Transposed convolution operator in NCHW layout. Parameters @@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python( ) res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1)) return res_nhwc + + +def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1): + """Convolution operator in NCHW layout. + + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + w_np : numpy.ndarray + 4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + output_padding : int or a list/tuple of two ints + Use to disambiguate the output shape. 
+ + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + a_slices = np.array_split(a_np, groups, axis=1) + w_slices = np.array_split(w_np, groups, axis=0) + b_slices = [ + _conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding) + for a_slice, w_slice in zip(a_slices, w_slices) + ] + b_np = np.concatenate(b_slices, axis=1) + return b_np diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index c27227b2eb73..6e4a034b0364 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -1070,19 +1070,29 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a IndexExpr channels, dilated_ksize_y, dilated_ksize_x; auto dshape_nchw = trans_in_layout.ForwardShape(data->shape); + if (param->groups > 1) { + ICHECK(weight->shape.defined()) + << "Weight shape must be specified when groups is greater than 1."; + } // infer weight if the kernel_size and channels are defined if (param->kernel_size.defined() && param->channels.defined()) { ICHECK_EQ(param->kernel_size.size(), 2); ICHECK_EQ(param->dilation.size(), 2); - Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), - param->kernel_size[0], param->kernel_size[1]}); - + tvm::tir::ExprDeepEqual expr_equal; + Array wshape; + if (expr_equal(param->channels, 1)) { + wshape = {{dshape_nchw[1], param->channels, param->kernel_size[0], param->kernel_size[1]}}; + channels = param->groups; + } else { + wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0], + param->kernel_size[1]}}; + channels = param->channels; + } wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - channels = param->channels; DataType weight_dtype = data->dtype; if (weight != nullptr) { @@ -1108,7 +1118,8 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a << " channels=" << param->channels << " wshape=" << Array(wshape); } if (!dshape_nchw[1].as() && !wshape[0].as()) { - ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); + ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), + indexdiv(wshape[0], param->groups))); } channels = wshape[1]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; diff --git a/tests/python/topi/python/test_topi_group_conv2d_transpose.py b/tests/python/topi/python/test_topi_group_conv2d_transpose.py new file mode 100644 index 000000000000..90b7500c6cd4 --- /dev/null +++ b/tests/python/topi/python/test_topi_group_conv2d_transpose.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example code to do group transpose convolution.""" + +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import te, topi +from tvm.contrib.pickle_memoize import memoize +from tvm.topi.utils import get_const_tuple + +_group_conv2d_nchw_implement = { + "generic": ( + topi.nn.group_conv2d_transpose_nchw, + topi.generic.schedule_group_conv2d_transpose_nchw, + ), +} + + +def verify_group_conv2d_transpose_nchw( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + output_padding, + groups, +): + print( + "Workload: (%d, %d, %s, %d, %s, %s, %s, %s, %d)" + % (batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding, groups) + ) + + in_height, in_width = in_size + kernel_height, kernel_width = kernel + + A = te.placeholder((batch, in_channel, in_height, in_width), name="A") + W = te.placeholder((in_channel, num_filter // groups, kernel_height, kernel_width), name="W") + bias = te.placeholder((num_filter, 1, 1), name="bias") + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_group_conv2d_transpose.verify_group_conv2d_transpose_nchw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np = tvm.topi.testing.conv2d_transpose_nchw_python( + a_np, w_np, stride, padding, output_padding, groups + ).astype(dtype) + + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_target(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + + print("Running on target: %s" % target) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _group_conv2d_nchw_implement) + C = fcompute(A, W, stride, padding, dtype, output_padding, groups) + s = fschedule([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + func = tvm.build( + s, + [A, W, C], + target, + name="group_conv2d_transpose_%d_%d_%s_%d_%s_%s_%s_%s_%d" + % ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + output_padding, + groups, + ), + ) + func(a, w, c) + c = c.numpy() + for measurement, reference in zip(c, c_np): + tvm.testing.assert_allclose(measurement, reference, rtol=1e-5) + + for target in ["llvm"]: + check_target(target) + + +@tvm.testing.uses_gpu +def test_group_conv2d_transpose_nchw(): + verify_group_conv2d_transpose_nchw(1, 1, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0), (0, 0), 1) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw( + 1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1), (0, 0), 1 + ) + verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 4, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 2) + verify_group_conv2d_transpose_nchw(1, 9, (32, 32), 9, (5, 5), (1, 1), (0, 0, 0, 0), (0, 0), 3) + 
verify_group_conv2d_transpose_nchw(1, 4, (32, 32), 16, (5, 5), (2, 2), (1, 1, 1, 1), (0, 0), 4) + verify_group_conv2d_transpose_nchw( + 1, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 2 + ) + verify_group_conv2d_transpose_nchw( + 1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (0, 0), 16 + ) + verify_group_conv2d_transpose_nchw( + 1, 512, (8, 1), 256, (31, 1), (2, 1), (14, 0, 15, 0), (1, 0), 16 + ) + verify_group_conv2d_transpose_nchw( + 1, 64, (64, 64), 64, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 64 + ) + verify_group_conv2d_transpose_nchw( + 1, 128, (32, 32), 128, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 128 + ) + verify_group_conv2d_transpose_nchw( + 1, 256, (16, 16), 256, (4, 4), (1, 1), (0, 0, 0, 0), (0, 0), 256 + ) + + +if __name__ == "__main__": + test_group_conv2d_transpose_nchw()
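
A short plain-Python restatement of the index arithmetic inside the new group_conv2d_transpose_nchw te.compute may help review: for an output channel ff, the group id is ff // out_c (written as ff // ((out_c * groups) // groups) in the patch, which reduces to the same value), and both the padded data and the transformed kernel are read from that group's slice of the input channels. The helper below is illustrative only and is not part of the patch.

    def group_transpose_indices(ff, rc, out_c, in_c, groups):
        """Plain-Python restatement of the index math used in the te.compute."""
        g = ff // out_c                        # ff // ((out_c * groups) // groups) in the patch
        data_channel = g * (in_c // groups) + rc
        kernel_out = ff % out_c                # first axis of kernel_transform
        kernel_in = g * (in_c // groups) + rc  # second axis of kernel_transform
        return data_channel, kernel_out, kernel_in

    # Example: 2 groups, 2 output channels per group, 4 input channels.
    # Output channel 3 belongs to group 1 and reads input channels 2..3.
    print(group_transpose_indices(ff=3, rc=0, out_c=2, in_c=4, groups=2))  # (2, 1, 2)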
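
The grouped semantics of the new conv2d_transpose_nchw_python reference can also be checked by hand: split the input along its channel axis and the weight along its in-channel axis into `groups` slices, run an ordinary transposed convolution per slice, and concatenate the results along the output channel axis. A minimal sketch, with illustrative shapes and the weight laid out as (in_channel, num_filter // groups, kh, kw) as in the new test:

    import numpy as np
    import tvm.topi.testing

    groups = 2
    a = np.random.uniform(size=(1, 4, 8, 8)).astype("float32")  # NCHW input
    w = np.random.uniform(size=(4, 2, 3, 3)).astype("float32")  # grouped filters

    # Grouped reference added by this patch.
    out = tvm.topi.testing.conv2d_transpose_nchw_python(
        a, w, (1, 1), (0, 0, 0, 0), (0, 0), groups
    )

    # Equivalent manual decomposition: per-group transposed conv, then concat.
    manual = np.concatenate(
        [
            tvm.topi.testing.conv2d_transpose_nchw_python(ai, wi, (1, 1), (0, 0, 0, 0), (0, 0))
            for ai, wi in zip(
                np.array_split(a, groups, axis=1), np.array_split(w, groups, axis=0)
            )
        ],
        axis=1,
    )
    np.testing.assert_allclose(out, manual, rtol=1e-5)
    print(out.shape)  # (1, 4, 10, 10): (8 - 1) * 1 + 3 = 10 per spatial dim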
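
With the groups == 1 asserts removed from the generic and x86 strategies and the Conv2DTransposeRel change, a grouped nn.conv2d_transpose should compile end to end through Relay. A hedged sketch of such a use (parameter values are illustrative; it assumes the default NCHW layout and a weight of shape (in_channel, channels // groups, kh, kw), which is what the updated type relation checks):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    data = relay.var("data", shape=(1, 4, 8, 8), dtype="float32")
    weight = relay.var("weight", shape=(4, 2, 3, 3), dtype="float32")
    out = relay.nn.conv2d_transpose(
        data,
        weight,
        channels=4,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding=(0, 0),
        groups=2,  # rejected by the old groups == 1 assertion, accepted with this patch
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")

    dev = tvm.cpu(0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input("data", np.random.uniform(size=(1, 4, 8, 8)).astype("float32"))
    module.set_input("weight", np.random.uniform(size=(4, 2, 3, 3)).astype("float32"))
    module.run()
    print(module.get_output(0).shape)  # expected: (1, 4, 10, 10)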