Implement 1d deconvolution
alexgl-github committed Dec 9, 2019
1 parent 7cf1ead commit 9e5facf
Showing 16 changed files with 818 additions and 14 deletions.
56 changes: 56 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -315,6 +315,62 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
}
};

/*! \brief Attributes used in 1D transposed convolution operator */
struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
IndexExpr channels;
Array<IndexExpr> kernel_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> output_padding;
Array<IndexExpr> dilation;
int groups;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") {
TVM_ATTR_FIELD(channels)
.set_default(NullValue<IndexExpr>())
.describe("The dimensionality of the output space"
"i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
.describe("The strides of the convolution.");
TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0}))
.describe("Zero-padding added to one side of the output.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded "
"on both sides for padding number of points.");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(data_layout).set_default("NCW")
.describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the"
"'W' dimension.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
.describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
"'O', 'I', 'W' stands for num_filter, input_channel, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
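A note on how these defaults surface in practice: the Python wrapper registered later in this commit mirrors them, so a call built without explicit arguments carries single-element strides/padding and the NCW/OIW layouts. A minimal sketch of a hypothetical session, assuming a TVM build that includes this commit and the 2019-era relay API:

from tvm import relay

# Sketch: inspect defaults on a freshly constructed nn.conv1d_transpose call node.
data = relay.var("data", shape=(1, 3, 10), dtype="float32")     # NCW
weight = relay.var("weight", shape=(3, 8, 4), dtype="float32")  # OIW
call = relay.nn.conv1d_transpose(data, weight, channels=8, kernel_size=(4,))
print(call.attrs.strides)      # [1]
print(call.attrs.padding)      # [0]
print(call.attrs.data_layout)  # NCW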

/*! \brief Attributes for max pool operator */
struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> pool_size;
1 change: 1 addition & 0 deletions python/tvm/relay/_parser.py
@@ -141,6 +141,7 @@ def __call__(self, args, attrs, type_args):
"nn.softmax": op.nn.softmax,
"reshape": op.reshape,
"nn.conv2d_transpose": op.nn.conv2d_transpose,
"nn.conv1d_transpose": op.nn.conv1d_transpose,
"concatenate": op.concatenate,
"nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros,
20 changes: 7 additions & 13 deletions python/tvm/relay/frontend/mxnet.py
@@ -207,29 +207,23 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid(
             'Only "NCW" data layout is supported for 1D Convolution')
-    data_layout = "NCHW"
     channel_axis = 1
-    kernel_layout = "OIHW"
-
+    kernel_layout = "OIW"
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
-    new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel")
-    new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
-    new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,))
-    new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
-    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
+    new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
+    new_attrs["strides"] = attrs.get_int_tuple("stride", (1,))
+    new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0,))
+    new_attrs["padding"] = attrs.get_int_tuple("pad", (0,))
+    new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1,))
     new_attrs["groups"] = attrs.get_int("num_group", 1)
     new_attrs["data_layout"] = data_layout
     new_attrs["kernel_layout"] = kernel_layout
     use_bias = not attrs.get_bool("no_bias", True)
-    data = _op.expand_dims(inputs[0], axis=2)
-    kernel = _op.expand_dims(inputs[1], axis=2)
-    res = _op.nn.conv2d_transpose(data, kernel, **new_attrs)
-
+    res = _op.nn.conv1d_transpose(inputs[0], inputs[1], **new_attrs)
     if use_bias:
         assert len(inputs) == 3
         res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
-    res = _op.squeeze(res, axis=[2])
     return res
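With this change a 1-D MXNet Deconvolution converts straight to nn.conv1d_transpose instead of being expanded to 2-D and squeezed back. A hedged end-to-end sketch (requires mxnet installed; names follow the 2019-era frontend API):

import mxnet as mx
from tvm import relay

# Sketch: the converted IR should contain nn.conv1d_transpose,
# not the old expand_dims/conv2d_transpose/squeeze sequence.
data = mx.sym.var("data")
net = mx.sym.Deconvolution(data, kernel=(4,), stride=(2,), pad=(1,),
                           num_filter=8, no_bias=True, name="deconv")
mod, params = relay.frontend.from_mxnet(net, shape={"data": (1, 3, 10)})
print(mod)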


32 changes: 32 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -348,6 +348,38 @@ def legalize_conv2d_transpose(attrs, inputs, types):

reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# conv1d_transpose
@reg.register_compute("nn.conv1d_transpose")
def compute_conv1d_transpose(attrs, inputs, out_dtype, target):
"""Compute definition of conv1d_transpose"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout == "NCW", "conv1d_transpose only supports NCW layout"
assert dilation == (1,), "conv1d_transpose does not support dilation"
assert groups == 1, "conv1d_transpose only supports groups == 1"
out = topi.nn.conv1d_transpose_ncw(
inputs[0], inputs[1], strides, padding, out_dtype)
output_padding = get_const_tuple(attrs.output_padding)
out = topi.nn.pad(out,
[0, 0, 0], [0, 0, output_padding[0]])
return [out]
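The resulting width is the topi conv1d_transpose output plus the trailing pad, i.e. the usual transposed-convolution arithmetic. A small self-contained check of that formula (plain Python, no TVM required):

def conv1d_transpose_out_width(width, kernel, stride, pad, out_pad):
    # Mirrors the compute above: the topi kernel yields
    # (width - 1) * stride - 2 * pad + kernel columns, and topi.nn.pad
    # appends out_pad more on the right.
    return (width - 1) * stride - 2 * pad + kernel + out_pad

assert conv1d_transpose_out_width(10, 4, 2, 1, 0) == 20
assert conv1d_transpose_out_width(10, 4, 2, 1, 1) == 21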


@reg.register_schedule("nn.conv1d_transpose")
def schedule_conv1d_transpose(attrs, outs, target):
"""Schedule definition of conv1d_transpose"""
with target:
return topi.generic.schedule_conv1d_transpose_ncw(outs)

reg.register_pattern("nn.conv1d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# bias_add
reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)
Expand Down
67 changes: 67 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1,3 +1,4 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -257,6 +258,72 @@ def conv2d_transpose(data,
kernel_layout, out_layout, output_padding, out_dtype)


def conv1d_transpose(data,
weight,
strides=(1,),
padding=(0,),
dilation=(1,),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCW",
kernel_layout="OIW",
out_layout="",
output_padding=(0,),
out_dtype=""):
"""One dimensional transposed convolution operator.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : Tuple[int], optional
The strides of convolution.
padding : Tuple[int], optional
The padding of convolution on both sides of inputs.
dilation : Tuple[int], optional
Specifies the dilation rate to be used for dilated convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial of the convolution kernel.
groups : int, optional
Number of groups for grouped convolution.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout
output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.conv1d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, output_padding, out_dtype)
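A hedged usage sketch for the new API (assumes an LLVM-enabled TVM build containing this commit; executor names follow the 2019-era relay API):

import numpy as np
from tvm import relay

# Sketch: build and run one conv1d_transpose; data is NCW, weight is OIW.
data = relay.var("data", shape=(1, 3, 10), dtype="float32")
weight = relay.var("weight", shape=(3, 8, 4), dtype="float32")
y = relay.nn.conv1d_transpose(data, weight, strides=(2,), padding=(1,))
func = relay.Function([data, weight], y)

d = np.random.uniform(size=(1, 3, 10)).astype("float32")
w = np.random.uniform(size=(3, 8, 4)).astype("float32")
intrp = relay.create_executor(kind="graph", target="llvm")
out = intrp.evaluate(func)(d, w)
print(out.asnumpy().shape)  # (1, 8, 20): (10 - 1) * 2 - 2 * 1 + 4 = 20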


def softmax(data, axis=-1):
r"""Computes softmax.
154 changes: 154 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -328,6 +328,160 @@ (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);


// relay.nn.conv1d_transpose
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);

bool Conv1DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;

static const Layout kNCW("NCW");
static const Layout kOIW("OIW");

const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;

const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got "<< kernel_layout;

Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;

IndexExpr channels, dilated_ksize_x;

auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);

// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
CHECK_EQ(param->dilation.size(), 1);

Array<IndexExpr> wshape({dshape_ncw[1],
indexdiv(param->channels, param->groups),
param->kernel_size[0]});

wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
channels = param->channels;

// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
}
// compute the output shape
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
2 * param->padding[0] + param->output_padding[0]));

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
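The relation can be exercised from Python through the type inferencer; a hedged sketch (2019-era API, where the module class was relay.Module):

from tvm import relay

# Sketch: Conv1DTransposeRel fills in the output type during InferType.
data = relay.var("data", shape=(1, 3, 10), dtype="float32")    # NCW
weight = relay.var("weight", shape=(3, 8, 4), dtype="float32") # OIW
out = relay.nn.conv1d_transpose(data, weight, strides=(2,),
                                channels=8, kernel_size=(4,))
mod = relay.Module.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)
# Expected: Tensor[(1, 8, 22), float32]
# oshape[2] = strides[0]*(w-1) + dilated_ksize - 2*padding[0] + output_padding[0]
#           = 2*9 + 4 - 0 + 0 = 22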


Expr MakeConv1DTranspose(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_node<Conv1DTransposeAttrs>();
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->output_padding = std::move(output_padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv1d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.conv1d_transpose")
.set_body_typed(MakeConv1DTranspose);

RELAY_REGISTER_OP("nn.conv1d_transpose")
.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
(batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (in_channels, channels, kernel_size[0])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
(batch_size, channels, out_width) if `layout` is `NCW`.
out_width is calculated as::
out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv1DTransposeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv1DTranspose", Conv1DTransposeRel);


// relay.nn.contrib_conv2d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
