Added group transposed convolution
The change includes the TOPI implementation, tests, and the generic and x86
strategies for group transposed convolution.

Signed-off-by: Alicja Kwasniewska <alicja.kwasniewska@sima.ai>
alicja-SiMa-ai committed Nov 5, 2021
1 parent 19b23b9 commit 8eb15b7
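
As an illustration of what the new strategies enable, a grouped transposed convolution can now be expressed directly in Relay. The sketch below is illustrative only (shapes, variable names, and the llvm target are assumptions, not taken from this commit); with groups > 1 the build dispatches to the new group_conv2d_transpose_nchw implementation instead of tripping the old groups == 1 assertion.

# Minimal Relay-level sketch, assuming this patch is applied.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8), dtype="float32")
# conv2d_transpose weights are laid out [in_channel, out_channel // groups, kH, kW].
weight = relay.var("weight", shape=(4, 2, 3, 3), dtype="float32")
out = relay.nn.conv2d_transpose(
    data, weight, strides=(2, 2), padding=(1, 1), kernel_size=(3, 3),
    channels=4, groups=2,
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
lib = relay.build(mod, target="llvm")  # selects group_conv2d_transpose_nchw.x86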
Showing 7 changed files with 344 additions and 21 deletions.
25 changes: 17 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):


# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
def wrap_compute_conv2d_transpose(topi_compute, has_groups=False):
"""wrap conv2d_transpose topi compute"""

def compute_conv2d_transpose(attrs, inputs, out_dtype):
@@ -456,7 +456,10 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
if has_groups:
args.append(attrs.groups)
out = topi_compute(*args)
return [out]

return compute_conv2d_transpose
@@ -471,13 +474,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
else: # group_transpose_conv2d
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.generic",
)
return strategy


18 changes: 12 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
@@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.x86",
)
return strategy


17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -345,6 +345,23 @@ def schedule_conv2d_transpose_nchw(outs):
return _default_schedule(outs, False)


def schedule_group_conv2d_transpose_nchw(outs):
"""Schedule for group_conv2d_transpose_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv2d_transpose_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_conv1d_transpose_ncw(outs):
"""Schedule for conv1d_transpose_ncw
88 changes: 87 additions & 1 deletion python/tvm/topi/nn/conv2d_transpose.py
@@ -22,7 +22,7 @@
from .dilate import dilate
from .pad import pad
from .utils import get_pad_tuple
from ..utils import simplify
from ..utils import get_const_tuple, simplify


def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding):
@@ -173,3 +173,89 @@ def conv2d_transpose_legalize(attrs, inputs, types):
return out

return None


def group_conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding, groups):
"""Group convolution operator in NCHW layout.
Parameters
----------
Input : tvm.te.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.te.Tensor
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints
out_dtype : str
The output data type. This is used for mixed precision.
output_padding : tuple of ints
Used to get the right output shape for gradients
groups : int
Number of groups
Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""

if groups == 1:
return conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding)

if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

batch, in_channel, _, _ = get_const_tuple(Input.shape)
in_channel_w, _, _, _ = get_const_tuple(Filter.shape)

assert in_channel % groups == 0, "input channels must divide group size"
assert in_channel_w % groups == 0, "weight channels must divide group size"

data_pad, kernel_transform = conv2d_transpose_nchw_preprocess(
Input, Filter, stride, padding, out_dtype, output_padding
)
batch, in_c, in_h, in_w = data_pad.shape
out_c, _, filter_h, filter_w = kernel_transform.shape

out_c = simplify(out_c)
out_height = simplify(in_h - filter_h + 1)
out_width = simplify(in_w - filter_w + 1)

# compute graph
rc = te.reduce_axis((0, in_c // groups), name="rc")
ry = te.reduce_axis((0, filter_h), name="ry")
rx = te.reduce_axis((0, filter_w), name="rx")
return te.compute(
(batch, out_c * groups, out_height, out_width),
lambda nn, ff, yy, xx: te.sum(
data_pad[
nn,
ff // ((out_c * groups) // groups) * (in_c // groups) + rc,
yy + ry,
xx + rx,
].astype(out_dtype)
* kernel_transform[
ff % out_c, ff // ((out_c * groups) // groups) * (in_c // groups) + rc, ry, rx
].astype(out_dtype),
axis=[rc, ry, rx],
),
tag="group_conv2d_transpose_nchw",
)
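
For reference, the new compute and its default schedule can also be driven directly at the TOPI level; the following is a minimal sketch with illustrative placeholder shapes and target (not taken from the commit's tests):

# Minimal TOPI-level sketch, assuming this patch is applied.
import tvm
from tvm import te, topi

data = te.placeholder((1, 4, 8, 8), name="data")      # [batch, in_channel, H, W]
weight = te.placeholder((4, 2, 3, 3), name="weight")  # [in_channel, out_channel // groups, kH, kW]
out = topi.nn.group_conv2d_transpose_nchw(
    data, weight, (2, 2), (1, 1, 1, 1), "float32", (0, 0), 2
)
s = topi.generic.schedule_group_conv2d_transpose_nchw([out])
mod = tvm.lower(s, [data, weight, out], name="group_conv2d_transpose_nchw")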
40 changes: 39 additions & 1 deletion python/tvm/topi/testing/conv2d_transpose_python.py
@@ -22,7 +22,7 @@
from tvm.topi.nn.utils import get_pad_tuple


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
"""Transposed convolution operator in NCHW layout.
Parameters
@@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python(
)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1):
"""Convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
output_padding : int or a list/tuple of two ints
Used to disambiguate the output shape.
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [
_conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding)
for a_slice, w_slice in zip(a_slices, w_slices)
]
b_np = np.concatenate(b_slices, axis=1)
return b_np
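
The grouped NumPy reference can be exercised along these lines; the shapes and the expected output size follow standard transposed-convolution arithmetic and are illustrative assumptions rather than values from the new tests:

# Minimal sketch of using the grouped reference implementation.
import numpy as np
from tvm.topi.testing import conv2d_transpose_nchw_python

a_np = np.random.uniform(size=(1, 4, 8, 8)).astype("float32")  # NCHW input
w_np = np.random.uniform(size=(4, 2, 3, 3)).astype("float32")  # [in_channel, out_channel // groups, kH, kW]
b_np = conv2d_transpose_nchw_python(
    a_np, w_np, stride=(2, 2), padding=(1, 1), output_padding=(0, 0), groups=2
)
# Spatial size: (8 - 1) * 2 - 2 * 1 + 3 + 0 = 15; channels: 2 groups * 2 = 4.
assert b_np.shape == (1, 4, 15, 15)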
21 changes: 16 additions & 5 deletions src/relay/op/nn/convolution.h
@@ -1070,19 +1070,29 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
if (param->groups > 1) {
ICHECK(weight->shape.defined())
<< "Weight shape must be specified when groups is greater than 1.";
}

// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
ICHECK_EQ(param->dilation.size(), 2);

Array<IndexExpr> wshape({dshape_nchw[1], indexdiv(param->channels, param->groups),
param->kernel_size[0], param->kernel_size[1]});

tvm::tir::ExprDeepEqual expr_equal;
Array<IndexExpr> wshape;
if (expr_equal(param->channels, 1)) {
wshape = {{dshape_nchw[1], param->channels, param->kernel_size[0], param->kernel_size[1]}};
channels = param->groups;
} else {
wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0],
param->kernel_size[1]}};
channels = param->channels;
}
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;

DataType weight_dtype = data->dtype;
if (weight != nullptr) {
@@ -1108,7 +1118,8 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups),
indexdiv(wshape[0], param->groups)));
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];