[Relay/Topi][Op] Conv1D (apache#4639)
* added conv1d operators to topi.

* Started to add python testing.

* Added python conv1d implementation for testing.

* Wrote test but need to add cuda schedule :(

* Cuda schedules working for both conv1d layouts.

* All topi tests passing.

* Formatting topi.

* Removed pad_method option as it's probably overkill.

* Added relay op definition of conv1d.

* End2end conv1d working with onnx.

* Lint fixes.

* Formatting fixes.

* Rebase fix.

* Switched to array based attributes for consistency across convs.

* Improved onnx parsing and testing for convolutions.

* lint fix

* Tiny tweak.

* Bug fix

* Rebase fix.

* Add group ignore to onnx conv1d frontend.

* Unified MakeConv and fixed documentation.

* improved autopadding

* Addressed feedback and simplified onnx frontend.

* Format fix.

* Basic X86 NCW schedule working.

* Added nwc schedule.

* fixed name

* Added more tests and basic x86 schedules.

* Format fix.

* Added non power of two shape tests.
jwfromm authored and zhiics committed Mar 2, 2020
1 parent e192cf7 commit 5fbbc5f
Showing 19 changed files with 1,445 additions and 106 deletions.
48 changes: 48 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -49,6 +49,54 @@ struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
};


/*! \brief Attributes used in 1D convolution operators */
struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
  Array<IndexExpr> strides;
  Array<IndexExpr> padding;
  Array<IndexExpr> dilation;
  int groups;
  IndexExpr channels;
  Array<IndexExpr> kernel_size;
  std::string data_layout;
  std::string kernel_layout;
  std::string out_layout;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") {
    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
        .describe("Specifies the stride of the convolution.");
    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
        .describe("If padding is non-zero, then the input is implicitly zero-padded"
                  " on both sides for padding number of points.");
    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1}))
        .describe("Specifies the dilation rate to use for dilated convolution.");
    TVM_ATTR_FIELD(groups).set_default(1)
        .describe("Currently unused but may be added in the future.");
    TVM_ATTR_FIELD(channels)
        .describe("The number of output channels in the convolution."
                  " If it is not set, inferred by shape of the weight.")
        .set_default(NullValue<IndexExpr>());
    TVM_ATTR_FIELD(kernel_size)
        .describe("Specifies the dimensions of the convolution window.")
        .set_default(NullValue<Array<IndexExpr> >());
    TVM_ATTR_FIELD(data_layout).set_default("NCW")
        .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
                  " 'N', 'C', 'W' stand for batch, channel, and width"
                  " dimensions respectively. Convolution is applied on the 'W'"
                  " dimension.");
    TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
        .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc."
                  " 'O', 'I', 'W' stand for num_filter, input_channel, and width"
                  " dimensions respectively.");

    // use 0 bits to indicate none.
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type, set to explicit type under mixed precision setting.");
  }
};


/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
  Array<IndexExpr> strides;
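Each field of Conv1DAttrs surfaces one-to-one as a keyword argument of relay.nn.conv1d on the Python side. A minimal sketch of that mapping (the shapes and variable names here are illustrative only, not part of this commit):

from tvm import relay

# Defaults mirror the set_default values above:
# strides=(1,), padding=(0, 0), dilation=(1,), groups=1.
data = relay.var("data", shape=(1, 16, 100))     # NCW layout
weight = relay.var("weight", shape=(32, 16, 3))  # OIW layout
conv = relay.nn.conv1d(data, weight,
                       strides=1,         # -> Conv1DAttrs::strides
                       padding=0,         # -> Conv1DAttrs::padding
                       dilation=1,        # -> Conv1DAttrs::dilation
                       channels=32,       # -> Conv1DAttrs::channels
                       kernel_size=3,     # -> Conv1DAttrs::kernel_size
                       data_layout="NCW",
                       kernel_layout="OIW")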
51 changes: 28 additions & 23 deletions python/tvm/relay/frontend/onnx.py
@@ -267,22 +267,25 @@ class Conv(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        # infer pads for auto_pad
+        # Use shape of input to determine convolution type.
+        input_shape = infer_shape(inputs[0])
+
        if 'auto_pad' in attr:
            attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
            if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
-                input_shape = infer_shape(inputs[0])
-                in_h, in_w = input_shape[2], input_shape[3]
-                stride_h, stride_w = attr['strides']
-                kernel_h, kernel_w = attr['kernel_shape']
-                dilation_h, dilation_w = attr['dilations']
-                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
-                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
-                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
-                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
-                attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
+                pad_tuple = []
+                for axis in range(len(input_shape) - 2):
+                    axis_shape = input_shape[2 + axis]
+                    stride = attr['strides'][axis]
+                    kernel = attr['kernel_shape'][axis]
+                    dilation = attr['dilations'][axis]
+                    dilated_kernel = (kernel - 1) * dilation + 1
+                    pad = get_pad_pair(axis_shape, dilated_kernel, stride)
+                    pad_tuple.append(pad)
+                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
+                attr['pads'] = pad_tuple
            elif attr['auto_pad'] == 'VALID':
-                attr['pads'] = (0, 0)
+                attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)])
            elif attr['auto_pad'] == 'NOTSET':
                pass
            else:
@@ -294,10 +297,12 @@ def _impl_v1(cls, inputs, attr, params):
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
-                'dilations': ('dilation', (0, 0)),
-                'pads': ('padding', (0, 0), revert_caffe2_pad),
-                'group': ('groups', 1)},
+                'dilations': ('dilation', 1),
+                'pads': ('padding', 0),
+                'group': ('groups', 1)
+            },
            custom_check=dimension_constraint())(inputs[:2], attr, params)
+
        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
@@ -713,8 +718,8 @@ def _impl_v9(cls, inputs, attr, params):
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
-        attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method,
-                'layout':'NCHW', 'align_corners':True}
+        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method,
+                'layout': 'NCHW', 'align_corners': True}
        return AttrCvt('upsampling')(inputs, attr)


@@ -848,7 +853,7 @@ class Gather(OnnxOpConverter):
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        return AttrCvt('take',
-                       extras={'axis':axis})(inputs, {})
+                       extras={'axis': axis})(inputs, {})


class Greater(OnnxOpConverter):
@@ -880,7 +885,7 @@ def _impl_v1(cls, inputs, attr, params):
        beta = attr.get('beta', 0.75)
        bias = attr.get('bias', 1.0)
        nsize = attr.get('size')
-        attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias}
+        attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias}
        return AttrCvt('lrn')(inputs, attr)

class Maximum(OnnxOpConverter):
@@ -926,7 +931,7 @@ def _impl_v1(cls, inputs, attr, params):
        alpha = attr.get('alpha', 0.2)
        beta = attr.get('beta', 0.5)
        transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
-        attr = {'a_min':0, 'a_max':1}
+        attr = {'a_min': 0, 'a_max': 1}
        return AttrCvt('clip')([transformX], attr)

class Reduce(OnnxOpConverter):
@@ -940,7 +945,7 @@ def _impl_v1(cls, inputs, attr, params):
        else:
            axis_len = len(infer_shape(inputs[0]))
            axis = list(range(axis_len))
-        attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)}
+        attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
        return AttrCvt(cls.name)(inputs, attr)

class ReduceMax(Reduce):
@@ -975,7 +980,7 @@ class ArgMax(OnnxOpConverter):
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
-        attr = {'axis':axis, 'keepdims':keepdims}
+        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt('argmax')(inputs, attr)

class ArgMin(OnnxOpConverter):
@@ -985,7 +990,7 @@ class ArgMin(OnnxOpConverter):
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
-        attr = {'axis':axis, 'keepdims':keepdims}
+        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt('argmin')(inputs, attr)

class Softmax(OnnxOpConverter):
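For reference, the generalized auto_pad handling in the Conv converter above can be sketched standalone. This assumes the frontend's usual get_pad_pair rule (total SAME padding split across both sides, with the extra point at the end); the shapes below are illustrative only:

def get_pad_pair(in_size, kernel, stride):
    # Total padding so that every input point is covered by some window.
    if in_size % stride == 0:
        pad = max(kernel - stride, 0)
    else:
        pad = max(kernel - (in_size % stride), 0)
    return [pad // 2, pad - pad // 2]

# A 1D input in NCW layout has a single spatial axis.
input_shape = (1, 3, 224)
strides, kernel_shape, dilations = [2], [7], [1]

pad_tuple = []
for axis in range(len(input_shape) - 2):
    dilated_kernel = (kernel_shape[axis] - 1) * dilations[axis] + 1
    pad_tuple.append(get_pad_pair(input_shape[2 + axis], dilated_kernel, strides[axis]))

# zip(*...) groups all 'begin' pads, then all 'end' pads, matching the
# ONNX pads layout (x1_begin, x2_begin, ..., x1_end, x2_end, ...).
pads = tuple(val for pair in zip(*pad_tuple) for val in pair)
print(pads)  # (2, 3): total pad 5 for width 224, kernel 7, stride 2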
36 changes: 36 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -131,6 +131,42 @@ def schedule_sparse_transpose(attrs, outputs, target):

reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# Conv1D
@reg.register_compute("nn.conv1d")
def compute_conv1d(attrs, inputs, out_type, target):
    """Compute definition of conv1d"""
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    dilation = get_const_tuple(attrs.dilation)
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

    assert layout in ["NCW", "NWC"]
    if dilation[0] < 1:
        raise ValueError("dilation should be a positive value")

    return [topi.nn.conv1d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)]


@reg.register_schedule("nn.conv1d")
def schedule_conv1d(attrs, outs, target):
    """Schedule definition of conv1d"""
    layout = attrs.data_layout

    with target:
        if layout == "NCW":
            return topi.generic.schedule_conv1d_ncw(outs)
        elif layout == "NWC":
            return topi.generic.schedule_conv1d_nwc(outs)
    raise ValueError("No compatible schedule")


reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
def _find_conv2d_op(op):
"""Find the op with conv2d in its tag by traversing."""
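As a sanity check for the compute rule registered above, here is a NumPy reference of a strided, dilated 1D convolution in NCW layout. It is a testing sketch mirroring the conv1d semantics, not the topi implementation; conv1d_ncw_ref is a hypothetical helper name, and groups are not handled:

import numpy as np

def conv1d_ncw_ref(data, weight, stride=1, padding=0, dilation=1):
    # data: (N, C, W), weight: (O, C, K) -> out: (N, O, out_w)
    n, c, w = data.shape
    o, _, k = weight.shape
    data = np.pad(data, ((0, 0), (0, 0), (padding, padding)))
    dilated_k = (k - 1) * dilation + 1
    out_w = (w + 2 * padding - dilated_k) // stride + 1
    out = np.zeros((n, o, out_w), dtype=data.dtype)
    for x in range(out_w):
        # Dilated window of k taps starting at x * stride.
        window = data[:, :, x * stride : x * stride + dilated_k : dilation]
        # Contract over input channels and kernel taps.
        out[:, :, x] = np.einsum("nck,ock->no", window, weight)
    return out

x = np.random.randn(1, 3, 16).astype("float32")
f = np.random.randn(8, 3, 3).astype("float32")
y = conv1d_ncw_ref(x, f, stride=2, padding=1)  # y.shape == (1, 8, 8)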
125 changes: 118 additions & 7 deletions python/tvm/relay/op/nn/nn.py
@@ -21,6 +21,99 @@
from . import _make


def conv1d(data,
           weight,
           strides=1,
           padding=0,
           dilation=1,
           groups=1,
           channels=None,
           kernel_size=None,
           data_layout="NCW",
           kernel_layout="OIW",
           out_layout="",
           out_dtype=""):
    r"""1D convolution.

    This operator takes the weight as the convolution kernel
    and convolves it with data to produce an output.

    In the default case, where the data_layout is `NCW`
    and kernel_layout is `OIW`, conv1d takes in
    a data Tensor with shape `(batch_size, in_channels, width)`,
    and a weight Tensor with shape `(channels, in_channels, kernel_size)`
    to produce an output Tensor with the following rule:

    .. math::

        \mbox{out}[b, c, w] = \sum_{dw, k}
            \mbox{data}[b, k, \mbox{strides}[0] * w + dw] *
            \mbox{weight}[c, k, dw]

    Padding and dilation are applied to data and weight respectively before the computation.
    This operator accepts data layout specification.
    Semantically, the operator will convert the layout to the canonical layout
    (`NCW` for data and `OIW` for weight), perform the computation,
    then convert to the out_layout.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    weight : tvm.relay.Expr
        The weight expressions.

    strides : Optional[int, Tuple[int]]
        The strides of convolution.

    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of the input before convolution.

    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
        Currently unused for 1D convolution.

    channels : Optional[int]
        Number of output channels of this convolution.

    kernel_size : Optional[int, Tuple[int]]
        The spatial dimension of the convolution kernel.

    data_layout : Optional[str]
        Layout of the input.

    kernel_layout : Optional[str]
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output, by default, out_layout is the same as data_layout.

    out_dtype : Optional[str]
        Specifies the output data type for mixed precision conv1d.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, )
    if isinstance(strides, int):
        strides = (strides, )
    if isinstance(dilation, int):
        dilation = (dilation, )
    if isinstance(padding, int):
        padding = (padding, padding)
    return _make.conv1d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)


def conv2d(data,
           weight,
           strides=(1, 1),
@@ -66,13 +159,13 @@ def conv2d(data,
    weight : tvm.relay.Expr
        The weight expressions.

-    strides : Optional[Tuple[int]]
+    strides : Optional[int, Tuple[int]]
        The strides of convolution.

-    padding : Optional[Tuple[int]]
+    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of inputs before convolution.

-    dilation : Optional[Tuple[int]]
+    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
@@ -81,7 +174,7 @@ def conv2d(data,
    channels : Optional[int]
        Number of output channels of this convolution.

-    kernel_size : Optional[Tuple[int]]
+    kernel_size : Optional[int, Tuple[int]]
        The spatial dimensions of the convolution kernel.

    data_layout : Optional[str]
@@ -101,6 +194,15 @@ def conv2d(data,
    result : tvm.relay.Expr
        The computed result.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    if isinstance(padding, int):
        padding = (padding, padding)

    return _make.conv2d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)
@@ -154,10 +256,10 @@ def conv3d(data,
    strides : Optional[Tuple[int]]
        The strides of convolution.

-    padding : Optional[Tuple[int]]
+    padding : Optional[int, Tuple[int]]
        The padding of convolution on both sides of inputs before convolution.

-    dilation : Optional[Tuple[int]]
+    dilation : Optional[int, Tuple[int]]
        Specifies the dilation rate to be used for dilated convolution.

    groups : Optional[int]
@@ -166,7 +268,7 @@ def conv3d(data,
    channels : Optional[int]
        Number of output channels of this convolution.

-    kernel_size : Optional[Tuple[int]]
+    kernel_size : Optional[int, Tuple[int]]
        The spatial dimensions of the convolution kernel.

    data_layout : Optional[str]
@@ -186,6 +288,15 @@ def conv3d(data,
    result : tvm.relay.Expr
        The computed result.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)
    if isinstance(padding, int):
        padding = (padding, padding, padding)

    return _make.conv3d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        kernel_layout, out_layout, out_dtype)
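Putting the pieces together, an end-to-end sketch of the new Python API. The build/executor calls follow the relay API of this era and may differ in other TVM versions; shapes are illustrative:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3), dtype="float32")
out = relay.nn.conv1d(data, weight, strides=1, padding=1, dilation=1,
                      channels=8, kernel_size=3)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Compile for CPU and run once.
ex = relay.create_executor("graph", mod=mod, target="llvm")
x = np.random.randn(1, 3, 32).astype("float32")
w = np.random.randn(8, 3, 3).astype("float32")
y = ex.evaluate()(x, w)
print(y.shape)  # (1, 8, 32): width preserved with kernel_size=3, padding=1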
