Skip to content

Commit

Permalink
[Conv2DTransposed] Fix wrong shape check and add new TOPI module to s…
Browse files Browse the repository at this point in the history
…upport groups (apache#9465)

* f wrong type check in conv2d_transpose

* add test case for conv2d transpose

* add groups support for conv2d_transpose

* add  naive implementation and schedule for conv2d with groups

* enable tests for cpu and arm_cpu, raise error for cuda platform

* revert the cuda and generic strategy

* revert back the x86 strategy

* revert back the arm_cpu strategy

* revert back the arm_cpu strategy

* revert back the arm_cpu strategy

* fix EOF of x86

* clang lint updated c++ code

* update topi implementation

* Revert test

* Revert test

* add generic/x86/arm specialization for conv2d_transpose with groups > 1

* remove commentted codes

* fix lint

* fix lint

* fix c++ lint

* fix lint

* fix python lint

* remove comments and reformat

* lint file

* lint code

* fix lint

* update logging information in convolution.h

Co-authored-by: Alicja Kwasniewska <alicja.kwasniewska@sima.ai>
  • Loading branch information
2 people authored and mehrdadh committed Dec 1, 2021
1 parent 075a93c commit 2c2b91f
Show file tree
Hide file tree
Showing 8 changed files with 385 additions and 24 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
assert groups == 1, "only support groups == 1 when targetting cuda/gpu"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
Expand Down
26 changes: 18 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):


# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
def wrap_compute_conv2d_transpose(topi_compute, has_groups=False):
"""wrap conv2d_transpose topi compute"""

def compute_conv2d_transpose(attrs, inputs, out_dtype):
Expand All @@ -456,7 +456,11 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
# out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
if has_groups:
args.append(attrs.groups)
out = topi_compute(*args)
return [out]

return compute_conv2d_transpose
Expand All @@ -471,13 +475,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
else: # group_transpose_conv2d
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.generic",
)
return strategy


Expand Down
18 changes: 12 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.x86",
)
return strategy


Expand Down
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,23 @@ def schedule_group_conv2d_nchw(outs):
return _default_schedule(outs, False)


def schedule_group_conv2d_transpose_nchw(outs):
"""Schedule for group_conv2d_transpose_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_group_conv2d_nhwc(outs):
"""Schedule for group_conv2d_nhwc
Expand Down
129 changes: 129 additions & 0 deletions python/tvm/topi/nn/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
import collections

import tvm
from tvm import relay, te

Expand All @@ -25,6 +27,22 @@
from .utils import get_pad_tuple


def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
assert len(x) == n, f"Input can only have {n} elements, but got {len(x)} instead: {x}."
return x
return tuple(repeat(x, n))

return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding):
"""Transposed 2D convolution nchw forward operator.
Expand Down Expand Up @@ -116,6 +134,117 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype,
return Output


def group_conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding, groups):
"""Group convolution operator in NCHW layout.
Parameters
----------
data : tvm.te.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
kernel : tvm.te.Tensor
4-D with shape [in_channel, out_channel // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints
out_dtype : str
The output data type. This is used for mixed precision.
output_padding : tuple of ints
Used to get the right output shape for gradients
groups : int
number of groups
out_dtype : str
The output type. This is used for mixed precision.
Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if groups == 1:
return conv2d_transpose_nchw(data, kernel, stride, padding, out_dtype, output_padding)

# some pre-processing and prelimnary checks
if out_dtype is None:
out_dtype = data.dtype

batch, in_channels, in_h, in_w = data.shape
_, out_c, filter_h, filter_w = kernel.shape
assert (
in_channels % groups == 0
), f"input channels {in_channels} must divide group size {groups}"
# assert out_c % groups == 0, f"output channels {in_c} must divide group size {groups}"

strides = _pair(stride)
# padding = _pair(padding)
# output_padding = _pair(output_padding)
# dilation = _pair(dilation)

stride_h, stride_w = strides
opad_h, opad_w = output_padding
assert (
opad_h < stride_h and opad_w < stride_w
), f"[{output_padding}] opad_h:{opad_h} < stride_h:{stride_h} \
and opad_w:{opad_w} < stride_w:{stride_w} does not satisfy."
# dilate data
data_dilate = dilate(data, [1, 1, stride_h, stride_w], name="data_dilate")
# pad data
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right + opad_w
data_pad = pad(
data_dilate, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right], name="data_pad"
)
# transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees
kernel_transform = te.compute(
(out_c, in_channels, filter_h, filter_w),
lambda i, o, h, w: kernel[o][i][filter_h - 1 - h][filter_w - 1 - w],
name="kernel_transform",
)

batch, in_channels, in_h, in_w = data_pad.shape
out_c, _, filter_h, filter_w = kernel_transform.shape

# convolution stage
out_channels = simplify(out_c * groups)

out_h = simplify(in_h - filter_h + 1)
out_w = simplify(in_w - filter_w + 1)
dc = te.reduce_axis((0, in_channels // groups), name="dc")
dh = te.reduce_axis((0, filter_h), name="dh")
dw = te.reduce_axis((0, filter_w), name="dw")

# data: batch, in_channels, out_h, out_w
# weight: out_channels // G, in_channels, out_h, out_w
return te.compute(
(batch, out_channels, out_h, out_w),
lambda b, c, h, w: te.sum(
data_pad[
b, c // (out_channels // groups) * (in_channels // groups) + dc, h + dh, w + dw
].astype(out_dtype)
* kernel_transform[
c % (out_channels // groups),
c // (out_channels // groups) * (in_channels // groups) + dc,
dh,
dw,
].astype(out_dtype),
axis=[dc, dh, dw],
),
tag="group_conv2d_transpose_nchw",
)


def layout_transform(tensor: "relay.Expr", current_layout: str, desired_layout: str):
"""Transform a tensor with the current layout to the desired layout.
Expand Down
40 changes: 39 additions & 1 deletion python/tvm/topi/testing/conv2d_transpose_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.topi.nn.utils import get_pad_tuple


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
"""Transposed convolution operator in NCHW layout.
Parameters
Expand Down Expand Up @@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python(
)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1):
"""Convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
output_padding : int or a list/tuple of two ints
Use to disambiguate the output shape.
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [
_conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding)
for a_slice, w_slice in zip(a_slices, w_slices)
]
b_np = np.concatenate(b_slices, axis=1)
return b_np
21 changes: 13 additions & 8 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1053,18 +1053,18 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
ICHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< "Conv2DTransposed only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;

const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
ICHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from IOHW."
<< "Conv2DTransposed only support kernel layouts that are convertible from IOHW."
<< " But got " << kernel_layout;

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
ICHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< "Conv2DTransposed only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;

IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
Expand Down Expand Up @@ -1099,16 +1099,21 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
// check the size
ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< "Conv2DTransposed: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
<< "Conv2DTransposed: shape of weight is inconsistent with out_channels, "
<< " out_channels // groups != weight.shape[1] "
<< " out_channels=" << param->channels << " groups=" << param->groups
<< " weight.shape=" << Array<IndexExpr>(wshape);
}
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0]))
<< "Conv2DTransposed: shape of weight is inconsistent with in_channels."
<< " data.shape= " << Array<IndexExpr>(dshape_nchw) << " groups= " << param->groups
<< " weight.shape= " << Array<IndexExpr>(wshape);
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
Expand Down
Loading

0 comments on commit 2c2b91f

Please sign in to comment.