[Bugfix] Conv1DTranspose default kernel layout should be IOW #14482

Merged · 3 commits · Apr 4, 2023
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/nn.h
@@ -671,10 +671,10 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
           "dimensions respectively. Convolution is applied on the"
           "'W' dimension.");
   TVM_ATTR_FIELD(kernel_layout)
-      .set_default("OIW")
+      .set_default("IOW")
       .describe(
-          "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
-          "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+          "Dimension ordering of data and weight. Can be 'IOW', 'IOW16o16i', etc."
+          "'I', 'O', 'W' stands for input_channel, num_filter and width"
           "dimensions respectively.");
   TVM_ATTR_FIELD(out_layout)
       .set_default("")
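For readers skimming the diff, a minimal NumPy sketch of what the two layout strings mean for a 1-D transposed-convolution weight (shapes assumed for illustration, not taken from the PR):

import numpy as np

# A transposed convolution maps in_channels -> out_channels, so its weight is
# naturally indexed as (in_channels, out_channels, kernel_width): "IOW".
in_channels, out_channels, kernel_w = 8, 16, 3
weight_iow = np.zeros((in_channels, out_channels, kernel_w), dtype="float32")

# The old "OIW" default read the same buffer as (out_channels, in_channels, W),
# silently swapping the channel axes whenever in_channels != out_channels.
assert weight_iow.shape == (in_channels, out_channels, kernel_w)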
12 changes: 8 additions & 4 deletions python/tvm/relay/frontend/keras.py
@@ -282,6 +282,8 @@ def _convert_dense(


 def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):
+    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
+
     if input_shape is None:
         input_shape = keras_layer.input_shape
     _check_data_format(keras_layer)
@@ -290,19 +292,21 @@ def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):

     if data_layout == "NWC":
         kernel_layout = "WIO"
+        if is_deconv:
+            kernel_layout = "WOI"
     else:
         kernel_layout = "OIW"
+        if is_deconv:
+            kernel_layout = "IOW"
         msg = (
             "Kernel layout with {} is not supported for operator Convolution1D "
             "in frontend Keras."
         )
         raise tvm.error.OpAttributeUnImplemented(msg.format(data_layout))

-    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
-
     if is_deconv:
-        if kernel_layout == "OIW":
-            weight = weight.transpose([2, 0, 1])
+        if kernel_layout == "IOW":
+            weight = weight.transpose([2, 1, 0])
         kernel_w, n_filters, _ = weight.shape
     else:
         kernel_w, _, n_filters = weight.shape
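A small sketch of the weight permutation above, assuming Keras's convention of storing Conv1DTranspose kernels as (kernel_w, filters, in_channels), i.e. "WOI" (shapes are illustrative):

import numpy as np

kernel_w, n_filters, in_channels = 3, 16, 8
weight_woi = np.random.rand(kernel_w, n_filters, in_channels).astype("float32")  # "WOI"

# transpose([2, 1, 0]) reorders (W, O, I) -> (I, O, W), the new default layout.
# The old transpose([2, 0, 1]) produced (I, W, O), which matched neither the
# declared "OIW" nor the intended "IOW" ordering.
weight_iow = weight_woi.transpose([2, 1, 0])
assert weight_iow.shape == (in_channels, n_filters, kernel_w)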
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -304,7 +304,7 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid('Only "NCW" data layout is supported for 1D Convolution')
     channel_axis = 1
-    kernel_layout = "OIW"
+    kernel_layout = "IOW"
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
     new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
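No transpose is needed here. Assuming MXNet's usual Deconvolution convention (an assumption, not stated in the diff) of storing weights channels-in first, the buffer is already "IOW" and the converter only had to stop mislabeling it:

import numpy as np

# Assumed MXNet Deconvolution weight convention: (in_channels, num_filter, kernel_w).
in_channels, num_filter, kernel_w = 8, 16, 3
weight = np.zeros((in_channels, num_filter, kernel_w), dtype="float32")
# Declaring kernel_layout="IOW" now matches this buffer; "OIW" misread the
# first two axes whenever in_channels != num_filter.
assert weight.shape == (in_channels, num_filter, kernel_w)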
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1263,6 +1263,9 @@ def convolution(self, inputs, input_types):
         else:
             data_layout = "NCW"
             kernel_layout = "OIW"
+            if use_transpose:
+                # Transposed convolutions have IOW layout.
+                kernel_layout = "IOW"

         # Conv1d does not currently support grouped convolution so we convert it to conv2d
         is_grouped_conv1d = False
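The new comment in the diff matches PyTorch's storage convention: ConvTranspose1d keeps its weight as (in_channels, out_channels // groups, kernel_w). A quick check, assuming torch is installed:

import torch

deconv = torch.nn.ConvTranspose1d(in_channels=8, out_channels=16, kernel_size=3)
# Weight is (in_channels, out_channels, kernel_w) == "IOW", so the frontend can
# forward it to relay.nn.conv1d_transpose without any transpose.
assert tuple(deconv.weight.shape) == (8, 16, 3)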
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -604,7 +604,7 @@ def conv1d_transpose(
     channels=None,
     kernel_size=None,
     data_layout="NCW",
-    kernel_layout="OIW",
+    kernel_layout="IOW",
     out_layout="",
     output_padding=(0,),
     out_dtype="",
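With the corrected default, a caller passing an IOW-shaped weight no longer has to override kernel_layout. A minimal sketch (variable names and shapes assumed):

from tvm import relay

data = relay.var("data", shape=(1, 8, 32), dtype="float32")      # NCW
weight = relay.var("weight", shape=(8, 16, 3), dtype="float32")  # IOW
# kernel_layout now defaults to "IOW"; before this fix the same call required
# kernel_layout="OIW" together with a pre-transposed weight.
out = relay.nn.conv1d_transpose(data, weight, channels=16, kernel_size=(3,))
print(relay.Function([data, weight], out))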
18 changes: 10 additions & 8 deletions src/relay/op/nn/convolution.cc
@@ -926,7 +926,7 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (data == nullptr) return false;

   static const Layout kNCW("NCW");
-  static const Layout kOIW("OIW");
+  static const Layout kIOW("IOW");

   const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
   ICHECK(param != nullptr);
@@ -938,9 +938,9 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << "Conv only support input layouts that are convertible from NCW."
       << " But got " << in_layout;

-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOW);
   ICHECK(trans_kernel_layout.defined())
-      << "Conv only support kernel layouts that are convertible from OIW."
+      << "Conv only support kernel layouts that are convertible from IOW."
       << " But got " << kernel_layout;

   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -979,16 +979,18 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     ICHECK_EQ(param->kernel_size.size(), 1);
     // check the size
     ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
-        << "Conv1D: shape of weight is inconsistent with kernel_size, "
+        << "Conv1DTranspose: shape of weight is inconsistent with kernel_size, "
         << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (param->channels.defined()) {
-    ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
-        << "Conv1D: shape of weight is inconsistent with channels, "
-        << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
+    ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
+        << "Conv1DTranspose: shape of weight is inconsistent with channels, "
+        << " out_channels // groups != weight.shape[1] "
+        << " out_channels=" << param->channels << " groups=" << param->groups
+        << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
-    ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[0]));
   }
   channels = wshape[1];
   dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
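A worked example of the shapes the relation now checks (values assumed; symmetric padding, output-width formula paraphrased from the relation):

# data  (NCW): (1, 8, 32)  -> dshape_ncw[1] = 8 input channels
# weight (IOW): (8, 16, 3) -> wshape[0] must equal dshape_ncw[1] (it was
#               compared against indexdiv(dshape_ncw[1], groups) before this
#               fix), and wshape[1] must equal channels // groups.
in_w, stride, pad, dilation, ksize, out_pad = 32, 2, 0, 1, 3, 0
dilated_ksize = 1 + (ksize - 1) * dilation          # mirrors dilated_ksize_x
out_w = (in_w - 1) * stride - 2 * pad + dilated_ksize + out_pad
assert out_w == 65                                   # output shape (1, 16, 65)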