Fix group transpose conv2d #9443

Closed
25 changes: 17 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
@@ -446,7 +446,7 @@ def deformable_conv2d_strategy(attrs, inputs, out_type, target):


# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
def wrap_compute_conv2d_transpose(topi_compute, has_groups=False):
"""wrap conv2d_transpose topi compute"""

def compute_conv2d_transpose(attrs, inputs, out_dtype):
@@ -456,7 +456,10 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
args = [inputs[0], inputs[1], strides, padding, out_dtype, output_padding]
if has_groups:
args.append(attrs.groups)
out = topi_compute(*args)
return [out]

return compute_conv2d_transpose
@@ -471,13 +474,19 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
else: # group_transpose_conv2d
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.generic",
)
return strategy
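For context, a minimal end-to-end sketch (not part of this diff; the shapes, group count, and target are illustrative assumptions) of a Relay program that should now be lowered through the grouped implementation instead of tripping the old `groups == 1` assert:

```python
import tvm
from tvm import relay

# Assumed shapes: 8 input channels split into 8 groups (depthwise-style),
# weight laid out as [in_channel, out_channel // groups, kh, kw].
data = relay.var("data", shape=(1, 8, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(8, 1, 3, 3), dtype="float32")

out = relay.nn.conv2d_transpose(
    data, weight, strides=(2, 2), padding=(1, 1),
    channels=8, kernel_size=(3, 3), groups=8,
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
# groups > 1 should now select a group_conv2d_transpose_nchw.* implementation.
lib = relay.build(mod, target="llvm")
```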


18 changes: 12 additions & 6 deletions python/tvm/relay/op/strategy/x86.py
@@ -281,13 +281,19 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target):
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
if groups == 1:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.x86.conv2d_transpose_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw),
name="group_conv2d_transpose_nchw.x86",
)
return strategy


17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -345,6 +345,23 @@ def schedule_conv2d_transpose_nchw(outs):
return _default_schedule(outs, False)


def schedule_group_conv2d_transpose_nchw(outs):
"""Schedule for group_conv2d_transpose_nchw

Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv2d_transpose_nchw
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_conv1d_transpose_ncw(outs):
"""Schedule for conv1d_transpose_ncw

88 changes: 87 additions & 1 deletion python/tvm/topi/nn/conv2d_transpose.py
@@ -22,7 +22,7 @@
from .dilate import dilate
from .pad import pad
from .utils import get_pad_tuple
from ..utils import simplify
from ..utils import get_const_tuple, simplify


def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, output_padding):
@@ -173,3 +173,89 @@ def conv2d_transpose_legalize(attrs, inputs, types):
return out

return None


def group_conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding, groups):
"""Group convolution operator in NCHW layout.

Parameters
----------
Input : tvm.te.Tensor
4-D with shape [batch, in_channel, in_height, in_width]

Filter : tvm.te.Tensor
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints

out_dtype : str
The output data type. This is used for mixed precision.

output_padding : tuple of ints
Used to get the right output shape for gradients

groups : int
Number of groups

Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""

if groups == 1:
return conv2d_transpose_nchw(Input, Filter, stride, padding, out_dtype, output_padding)

if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

batch, in_channel, _, _ = get_const_tuple(Input.shape)
in_channel_w, _, _, _ = get_const_tuple(Filter.shape)

assert in_channel % groups == 0, "input channels must divide group size"
assert in_channel_w % groups == 0, "weight channels must divide group size"

data_pad, kernel_transform = conv2d_transpose_nchw_preprocess(
Input, Filter, stride, padding, out_dtype, output_padding
)
batch, in_c, in_h, in_w = data_pad.shape
out_c, _, filter_h, filter_w = kernel_transform.shape

out_c = simplify(out_c)
out_height = simplify(in_h - filter_h + 1)
out_width = simplify(in_w - filter_w + 1)

# compute graph
rc = te.reduce_axis((0, in_c // groups), name="rc")
ry = te.reduce_axis((0, filter_h), name="ry")
rx = te.reduce_axis((0, filter_w), name="rx")
return te.compute(
(batch, out_c * groups, out_height, out_width),
lambda nn, ff, yy, xx: te.sum(
data_pad[
nn,
ff // ((out_c * groups) // groups) * (in_c // groups) + rc,
yy + ry,
xx + rx,
].astype(out_dtype)
* kernel_transform[
ff % out_c, ff // ((out_c * groups) // groups) * (in_c // groups) + rc, ry, rx
].astype(out_dtype),
axis=[rc, ry, rx],
),
tag="group_conv2d_transpose_nchw",
)
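As a quick illustration of the new compute (shapes assumed, not taken from the PR's tests), it can be wired up directly at the TE level together with the default schedule added in `python/tvm/topi/generic/nn.py`:

```python
import tvm
from tvm import te, topi

# Assumed shapes: 4 input channels, 2 groups, 1 output channel per group,
# weight laid out as [in_channel, out_channel // groups, kh, kw].
data = te.placeholder((1, 4, 8, 8), name="data", dtype="float32")
weight = te.placeholder((4, 1, 3, 3), name="weight", dtype="float32")

with tvm.target.Target("llvm"):
    out = topi.nn.group_conv2d_transpose_nchw(
        data, weight, (2, 2), (1, 1), "float32", (0, 0), 2
    )
    s = topi.generic.schedule_group_conv2d_transpose_nchw([out])
    func = tvm.build(s, [data, weight, out], "llvm")
```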
40 changes: 39 additions & 1 deletion python/tvm/topi/testing/conv2d_transpose_python.py
@@ -22,7 +22,7 @@
from tvm.topi.nn.utils import get_pad_tuple


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
"""Transposed convolution operator in NCHW layout.

Parameters
@@ -141,3 +141,41 @@ def conv2d_transpose_nhwc_python(
)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc


def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding, groups=1):
"""Convolution operator in NCHW layout.

Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]

w_np : numpy.ndarray
4-D with shape [in_channel, num_filter // groups, filter_height, filter_width]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or str
Padding size, or ['VALID', 'SAME']

output_padding : int or a list/tuple of two ints
Used to disambiguate the output shape.

groups : int
Number of groups

Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [
_conv2d_transpose_nchw_python(a_slice, w_slice, stride, padding, output_padding)
for a_slice, w_slice in zip(a_slices, w_slices)
]
b_np = np.concatenate(b_slices, axis=1)
return b_np
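A small usage sketch of the grouped reference (shapes assumed), e.g. as the expected value when checking a compiled result:

```python
import numpy as np
from tvm.topi.testing import conv2d_transpose_nchw_python

# Assumed shapes: weight laid out as [in_channel, out_channel // groups, kh, kw].
a_np = np.random.uniform(size=(1, 4, 8, 8)).astype("float32")
w_np = np.random.uniform(size=(4, 1, 3, 3)).astype("float32")

b_np = conv2d_transpose_nchw_python(
    a_np, w_np, stride=(2, 2), padding=(1, 1), output_padding=(0, 0), groups=2
)
print(b_np.shape)  # (batch, out_channel, out_height, out_width)
```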
21 changes: 16 additions & 5 deletions src/relay/op/nn/convolution.h
@@ -1070,19 +1070,29 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
if (param->groups > 1) {
ICHECK(weight->shape.defined())
<< "Weight shape must be specified when groups is greater than 1.";
}

// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 2);
ICHECK_EQ(param->dilation.size(), 2);

Array<IndexExpr> wshape({dshape_nchw[1], indexdiv(param->channels, param->groups),
param->kernel_size[0], param->kernel_size[1]});

tvm::tir::ExprDeepEqual expr_equal;
Array<IndexExpr> wshape;
if (expr_equal(param->channels, 1)) {
wshape = {{dshape_nchw[1], param->channels, param->kernel_size[0], param->kernel_size[1]}};
channels = param->groups;
} else {
wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0],
param->kernel_size[1]}};
channels = param->channels;
}
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;

DataType weight_dtype = data->dtype;
if (weight != nullptr) {
Expand All @@ -1108,7 +1118,8 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
<< " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
}
if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups),
indexdiv(wshape[0], param->groups)));
}
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
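To see the effect of the relaxed channel check, a hedged sketch (assumed shapes) that runs Relay type inference on a grouped transposed convolution whose weight type is given explicitly; before this change the `AssertEQ` above would have rejected it:

```python
import tvm
from tvm import relay

# Assumed shapes: 8 channels split into 8 groups; Conv2DTransposeRel infers
# the kernel size and channels from the explicit weight type.
data = relay.var("data", shape=(1, 8, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(8, 1, 3, 3), dtype="float32")
out = relay.nn.conv2d_transpose(data, weight, strides=(2, 2), padding=(1, 1), groups=8)

mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expect a 4-D NCHW tensor type
```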