diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index a402eb1462d2..549eb67ebcb0 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -49,6 +49,54 @@ struct BiasAddAttrs : public tvm::AttrsNode { }; +/*! \brief Attributes used in 1D convolution operators */ +struct Conv1DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + IndexExpr channels; + Array kernel_size; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") { + TVM_ATTR_FIELD(strides).set_default(Array({1, })) + .describe("Specifies the stride of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation).set_default(Array({1, })) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1) + .describe("Currently unused but may be added in the future."); + TVM_ATTR_FIELD(channels) + .describe("The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") + .set_default(NullValue()); + TVM_ATTR_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(data_layout).set_default("NCW") + .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the 'W'" + "dimension."); + TVM_ATTR_FIELD(kernel_layout).set_default("OIW") + .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + + /*! \brief Attributes used in convolution operators */ struct Conv2DAttrs : public tvm::AttrsNode { Array strides; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1236b0dbe054..1bd8673cd3aa 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -267,22 +267,25 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # infer pads for auto_pad + # Use shape of input to determine convolution type. 
+ input_shape = infer_shape(inputs[0]) + if 'auto_pad' in attr: attr['auto_pad'] = attr['auto_pad'].decode('utf-8') if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): - input_shape = infer_shape(inputs[0]) - in_h, in_w = input_shape[2], input_shape[3] - stride_h, stride_w = attr['strides'] - kernel_h, kernel_w = attr['kernel_shape'] - dilation_h, dilation_w = attr['dilations'] - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h) - pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w) - attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1]) + pad_tuple = [] + for axis in range(len(input_shape) - 2): + axis_shape = input_shape[2 + axis] + stride = attr['strides'][axis] + kernel = attr['kernel_shape'][axis] + dilation = attr['dilations'][axis] + dilated_kernel = (kernel - 1) * dilation + 1 + pad = get_pad_pair(axis_shape, dilated_kernel, stride) + pad_tuple.append(pad) + pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) + attr['pads'] = pad_tuple elif attr['auto_pad'] == 'VALID': - attr['pads'] = (0, 0) + attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)]) elif attr['auto_pad'] == 'NOTSET': pass else: @@ -294,10 +297,12 @@ def _impl_v1(cls, inputs, attr, params): op_name=dimension_picker('conv'), transforms={ 'kernel_shape': 'kernel_size', - 'dilations': ('dilation', (0, 0)), - 'pads': ('padding', (0, 0), revert_caffe2_pad), - 'group': ('groups', 1)}, + 'dilations': ('dilation', 1), + 'pads': ('padding', 0), + 'group': ('groups', 1) + }, custom_check=dimension_constraint())(inputs[:2], attr, params) + use_bias = len(inputs) == 3 if use_bias: out = _op.nn.bias_add(out, inputs[2]) @@ -713,8 +718,8 @@ def _impl_v9(cls, inputs, attr, params): else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) - attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method, - 'layout':'NCHW', 'align_corners':True} + attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method, + 'layout': 'NCHW', 'align_corners': True} return AttrCvt('upsampling')(inputs, attr) @@ -848,7 +853,7 @@ class Gather(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 0) return AttrCvt('take', - extras={'axis':axis})(inputs, {}) + extras={'axis': axis})(inputs, {}) class Greater(OnnxOpConverter): @@ -880,7 +885,7 @@ def _impl_v1(cls, inputs, attr, params): beta = attr.get('beta', 0.75) bias = attr.get('bias', 1.0) nsize = attr.get('size') - attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias} + attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias} return AttrCvt('lrn')(inputs, attr) class Maximum(OnnxOpConverter): @@ -926,7 +931,7 @@ def _impl_v1(cls, inputs, attr, params): alpha = attr.get('alpha', 0.2) beta = attr.get('beta', 0.5) transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) - attr = {'a_min':0, 'a_max':1} + attr = {'a_min': 0, 'a_max': 1} return AttrCvt('clip')([transformX], attr) class Reduce(OnnxOpConverter): @@ -940,7 +945,7 @@ def _impl_v1(cls, inputs, attr, params): else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)} + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} return AttrCvt(cls.name)(inputs, attr) class ReduceMax(Reduce): @@ -975,7 +980,7 @@ class ArgMax(OnnxOpConverter): def _impl_v1(cls, inputs, 
attr, params): axis = attr.get('axis', 0) keepdims = attr.get('keepdims', True) - attr = {'axis':axis, 'keepdims':keepdims} + attr = {'axis': axis, 'keepdims': keepdims} return AttrCvt('argmax')(inputs, attr) class ArgMin(OnnxOpConverter): @@ -985,7 +990,7 @@ class ArgMin(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 0) keepdims = attr.get('keepdims', True) - attr = {'axis':axis, 'keepdims':keepdims} + attr = {'axis': axis, 'keepdims': keepdims} return AttrCvt('argmin')(inputs, attr) class Softmax(OnnxOpConverter): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 4bfac9212f56..e405fee916dc 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -131,6 +131,42 @@ def schedule_sparse_transpose(attrs, outputs, target): reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + +# Conv1D +@reg.register_compute("nn.conv1d") +def compute_conv1d(attrs, inputs, out_type, target): + """Compute definition of conv1d""" + strides = get_const_tuple(attrs.strides) + padding = get_const_tuple(attrs.padding) + dilation = get_const_tuple(attrs.dilation) + layout = attrs.data_layout + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if out_dtype in ("same", "") + else out_dtype) + + assert layout in ["NCW", "NWC"] + if dilation[0] < 1: + raise ValueError("dilation should be a positive value") + + return [topi.nn.conv1d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)] + + +@reg.register_schedule("nn.conv1d") +def schedule_conv1d(attrs, outs, target): + """Schedule definition of conv1d""" + layout = attrs.data_layout + + with target: + if layout == "NCW": + return topi.generic.schedule_conv1d_ncw(outs) + elif layout == "NCW": + return topi.generic.schedule_conv1d_nwc(outs) + raise ValueError("No compatible schedule") + + +reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE) + + # conv2d def _find_conv2d_op(op): """Find the op with conv2d in its tag by traversing.""" diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b7a091891ad3..2473aab63314 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -21,6 +21,99 @@ from . import _make +def conv1d(data, + weight, + strides=1, + padding=0, + dilation=1, + groups=1, + channels=None, + kernel_size=None, + data_layout="NCW", + kernel_layout="OIW", + out_layout="", + out_dtype=""): + r"""1D convolution. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCW` + and kernel_layout is `OIW`, conv1d takes in + a data Tensor with shape `(batch_size, in_channels, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_size)` + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, w] = \sum_{dw, k} + \mbox{data}[b, k, \mbox{strides}[0] * w + dw] * + \mbox{weight}[c, k, dw] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCW` for data and `OIW` for weight), perform the computation, + then convert to the out_layout. + + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : Optional[int, Tuple[int]] + The strides of convolution. 
+ + padding : Optional[int, Tuple[int]] + The padding of convolution on both sides of the input before convolution. + + dilation : Optional[int, Tuple[int]] + Specifies the dilation rate to be used for dilated convolution. + + groups : Optional[int] + Currently unused for 1D convolution. + + channels : Optional[int] + Number of output channels of this convolution. + + kernel_size : Optional[int, Tuple[int]] + The spatial dimension of the convolution kernel. + + data_layout : Optional[str] + Layout of the input. + + kernel_layout : Optional[str] + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : Optional[str] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, ) + if isinstance(strides, int): + strides = (strides, ) + if isinstance(dilation, int): + dilation = (dilation, ) + if isinstance(padding, int): + padding = (padding, padding) + return _make.conv1d(data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + def conv2d(data, weight, strides=(1, 1), @@ -66,13 +159,13 @@ def conv2d(data, weight : tvm.relay.Expr The weight expressions. - strides : Optional[Tuple[int]] + strides : Optional[int, Tuple[int]] The strides of convolution. - padding : Optional[Tuple[int]] + padding : Optional[int, Tuple[int]] The padding of convolution on both sides of inputs before convolution. - dilation : Optional[Tuple[int]] + dilation : Optional[int, Tuple[int]] Specifies the dilation rate to be used for dilated convolution. groups : Optional[int] @@ -81,7 +174,7 @@ def conv2d(data, channels : Optional[int] Number of output channels of this convolution. - kernel_size : Optional[Tuple[int]] + kernel_size : Optional[int, Tuple[int]] The spatial of the convolution kernel. data_layout : Optional[str] @@ -101,6 +194,15 @@ def conv2d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding) + return _make.conv2d(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) @@ -154,10 +256,10 @@ def conv3d(data, strides : Optional[Tuple[int]] The strides of convolution. - padding : Optional[Tuple[int]] + padding : Optional[int, Tuple[int]] The padding of convolution on both sides of inputs before convolution. - dilation : Optional[Tuple[int]] + dilation : Optional[int, Tuple[int]] Specifies the dilation rate to be used for dilated convolution. groups : Optional[int] @@ -166,7 +268,7 @@ def conv3d(data, channels : Optional[int] Number of output channels of this convolution. - kernel_size : Optional[Tuple[int]] + kernel_size : Optional[int, Tuple[int]] The spatial of the convolution kernel. data_layout : Optional[str] @@ -186,6 +288,15 @@ def conv3d(data, result : tvm.relay.Expr The computed result. 
""" + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding) + return _make.conv3d(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index e5a9a11fb012..2da35daba225 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -19,6 +19,12 @@ from ...attrs import Attrs from ..base import register_relay_attr_node + +@register_relay_attr_node +class Conv1DAttrs(Attrs): + """Attributes for nn.conv1d""" + + @register_relay_attr_node class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 627c42041992..6c3fb6187b43 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -34,8 +34,6 @@ namespace tvm { namespace relay { -// relay.nn.conv2d -TVM_REGISTER_NODE_TYPE(Conv2DAttrs); template Array > ConvInferCorrectLayout( @@ -52,21 +50,22 @@ Array > ConvInferCorrectLayout( params->data_layout : params->out_layout}}; } -// Positional relay function to create conv2d operator -// used by frontend FFI. -Expr MakeConv2D(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); + +template +Expr MakeConv(Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype, + std::string op_name) { + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -77,13 +76,77 @@ Expr MakeConv2D(Expr data, attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.conv2d"); + static const Op& op = Op::Get(op_name); return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } +// relay.nn.conv1d +TVM_REGISTER_NODE_TYPE(Conv1DAttrs); + +TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d") +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConv( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.conv1d"); +}); + + +RELAY_REGISTER_OP("nn.conv1d") +.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). + +This layer creates a convolution kernel that is convolved +with the layer input to produce a tensor of outputs. + +- **data**: This depends on the `layout` parameter. Input is 3D array of shape + (batch_size, in_channels, width) if `layout` is `NCW`. +- **weight**: (channels, in_channels, kernel_size) +- **out**: This depends on the `layout` parameter. 
Output is 3D array of shape + (batch_size, channels, out_width) if `layout` is `NCW`. + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(2) +.add_type_rel("Conv1D", Conv1DRel) +.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + + +// relay.nn.conv2d +TVM_REGISTER_NODE_TYPE(Conv2DAttrs); + TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d") -.set_body_typed(MakeConv2D); +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConv( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.conv2d"); +}); RELAY_REGISTER_OP("nn.conv2d") @@ -110,38 +173,24 @@ with the layer input to produce a tensor of outputs. // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); -// Positional relay function to create conv3d operator -// used by frontend FFI. -Expr MakeConv3D(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_object(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - static const Op& op = Op::Get("nn.conv3d"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); -} - - TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d") -.set_body_typed(MakeConv3D); +.set_body_typed([](Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + return MakeConv( + data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype, "nn.conv3d"); +}); RELAY_REGISTER_OP("nn.conv3d") diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index b61942d790e4..9e8f4b55d26e 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -33,6 +33,94 @@ namespace tvm { namespace relay { +template +bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + static const Layout kNCW("NCW"); + static const Layout kOIW("OIW"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCW." 
+ << " But got " << in_layout; + + const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCW." + << " But got " << out_layout; + + Array dshape_ncw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize; + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + Array wshape; + + wshape = {{param->channels, dshape_ncw[1], param->kernel_size[0]}}; + + wshape = trans_kernel_layout.BackwardShape(wshape); + channels = param->channels; + dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) ) + << "Conv1D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "Conv1D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << wshape; + } + CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1])); + channels = wshape[0]; + dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0]; + } + // dilation + Array oshape({dshape_ncw[0], channels, 0}); + + if (!dshape_ncw[2].as()) { + oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, + param->strides[0]) + 1); + } else { + oshape.Set(2, dshape_ncw[2]); + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + template bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 386adf872ff1..7eb09493df8c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1732,16 +1732,27 @@ def test_or(): verify_or(indata=[x, y], dtype=bool) -def verify_conv(x_shape, w_shape, y_shape, p): - node = helper.make_node('Conv', - inputs=['x', 'W'], - outputs=['y'], - kernel_shape=[3, 3], - # Default values for other attributes: - # strides=[1, 1], - # dilations=[1, 1], - # groups=1 - pads=p,) +def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET"): + if padding is None: + node = helper.make_node('Conv', + inputs=['x', 'W'], + outputs=['y'], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + 
# groups=1 + auto_pad=auto_pad) + else: + node = helper.make_node('Conv', + inputs=['x', 'W'], + outputs=['y'], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding) graph = helper.make_graph([node], 'conv_test', @@ -1761,18 +1772,35 @@ def verify_conv(x_shape, w_shape, y_shape, p): def test_conv(): # Convolution with padding - # (1, 1, 5, 5) input tensor - # (1, 1, 3, 3) tensor for convolution weights - # (1, 1, 5, 5) output tensor - # [1, 1, 1, 1] list for pads - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1]) + # Conv2D + verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1], [3, 3], [1, 1], [1, 1]) + # Conv1D + verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [1, 1], [3], [1], [1]) # Convolution without padding - # (1, 1, 5, 5) input tensor - # (1, 1, 3, 3) tensor for convolution weights - # (1, 1, 3, 3) output tensor - # [0, 0, 0, 0] list for pads - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0]) + # Conv2D + verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0], [3, 3], [1, 1], [1, 1]) + # Conv1D + verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), [0, 0], [3], [1], [1]) + + # Convolution with autopadding + verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), + None, [3, 3], [1, 1], [1, 1], + auto_pad="SAME_UPPER") + # Conv1D + verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), None, [3], [1], [1], auto_pad="SAME_UPPER") + + # Convolution with non uniform stride + verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), + None, [3, 3], [2, 2], [1, 1], + auto_pad="SAME_UPPER") + # Conv1D + verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), None, [3], [2], [1], auto_pad="SAME_UPPER") + + # Convolution with dilation + verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [2, 2, 2, 2], [3, 3], [1, 1], [2, 2]) + # Conv1D + verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [2, 2], [3], [1], [2]) def verify_convtranspose(x_shape, w_shape, y_shape, p): @@ -1838,15 +1866,15 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p raise ValueError("Pool method {} is not supported.".format(mode)) if pads is None: - pool_node = helper.make_node(node_type, - inputs=["x"], + pool_node = helper.make_node(node_type, + inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, auto_pad=auto_pad, strides=strides) else: - pool_node = helper.make_node(node_type, - inputs=["x"], + pool_node = helper.make_node(node_type, + inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, pads=pads, @@ -1867,6 +1895,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p model, [x_np], target, ctx, out_shape) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + def test_pooling(): for mode in ['max', 'average']: # Pool1D diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 6faad9997e4e..68f398396c05 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -31,6 +31,101 @@ def run_infer_type(expr): entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body + +def test_conv1d_infer_type(): + # symbolic in batch dimension + n, c, w = tvm.var("n"), 10, 224 + x = relay.var("x", relay.ty.TensorType((n, c, w), "float32")) + w = relay.var("w") + y = relay.nn.conv1d(x, w, + kernel_size=3, + padding=(1, 1), + channels=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 224), "float32") + assert 
yy.args[1].checked_type == relay.TensorType( + (2, 10, 3), "float32") + + # infer by shape of w, mixed precision + n, c, w = tvm.var("n"), 10, 224 + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + w = relay.var("w", relay.TensorType((2, 10, 3), "int8")) + y = relay.nn.conv1d(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 222), "int32") + + # infer shape in case of different dtypes for input and weight. + n, c, w = tvm.var("n"), 10, 224 + x = relay.var("x", relay.TensorType((n, c, w), "uint8")) + w = relay.var("w", relay.TensorType((2, 10, 3), "int8")) + y = relay.nn.conv1d(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 222), "int32") + + # Infer with NWC + n, c, w = 4, 32, 224 + x = relay.var("x", relay.TensorType((n, w, c), "int8")) + wt = relay.var("w") + y = relay.nn.conv1d(x, wt, + kernel_size=3, + padding=(1, 1), + channels=16, + data_layout="NWC", + out_dtype="int32") + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, w, 16), "int32") + + +def test_conv1d_run(): + def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape, + padding=(1, 1), + fref=None, + dilation=1, + except_targets=None, + **attrs): + if except_targets is None: + except_targets = [] + + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", dtype=dtype) + y = relay.nn.conv1d(x, w, + padding=padding, + dilation=dilation, + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + ref_res = topi.testing.conv1d_ncw_python( + data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation) + + for target, ctx in ctx_list(): + if target in except_targets: + continue + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + # normal conv1d + dshape = (1, 3, 224) + kshape = (10, 3, 3) + run_test_conv1d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=3) + # mixed precision + run_test_conv1d("int8", "int32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=3) + # dilated conv2d + dshape = (1, 3, 18) + kshape = (10, 3, 3) + run_test_conv1d("float32", "float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=3, dilation=3) + + def test_conv2d_infer_type(): # symbolic in batch dimension n, c, h, w = tvm.var("n"), 10, 224, 224 @@ -1114,6 +1209,7 @@ def test_bitpack_infer_type(): test_avg_pool2d_no_count_pad() test_lrn() test_l2_normalize() + test_conv1d_infer_type() test_conv2d_infer_type() test_conv3d_infer_type() test_bitpack_infer_type() @@ -1126,6 +1222,7 @@ def test_bitpack_infer_type(): test_conv2d_transpose_nchw_run() test_conv2d_transpose_nhwc_run() test_conv1d_transpose_ncw_run() + test_conv1d_run() test_conv2d_run() test_conv2d_winograd() test_conv3d_run() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 55255f42a05e..42af2aed5099 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -19,8 +19,8 @@ """CUDA specific declaration and schedules.""" from __future__ import absolute_import as _abs -from . 
import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \ - group_conv2d_nchw, dense, conv1d_transpose_ncw +from . import conv1d, conv2d, depthwise_conv2d, conv2d_transpose_nchw, \ + deformable_conv2d, group_conv2d_nchw, dense, conv1d_transpose_ncw from . import conv3d from .conv2d_hwcn import schedule_conv2d_hwcn from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc diff --git a/topi/python/topi/cuda/conv1d.py b/topi/python/topi/cuda/conv1d.py new file mode 100644 index 000000000000..201921564cbf --- /dev/null +++ b/topi/python/topi/cuda/conv1d.py @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Compute definition for conv1d with cuda backend""" +import tvm +from tvm import autotvm + +from .. import nn, generic +from ..util import traverse_inline, get_const_tuple + + +@autotvm.register_topi_compute(nn.conv1d, ['cuda', 'gpu'], ['direct']) +def conv1d_cuda(cfg, + data, + kernel, + strides, + padding, + dilation, + layout='NCW', + out_dtype='float32'): + """ 1D convolution forward operator for cuda backend. + + Parameters + ---------- + cfg : ConfigEntity + The config for this template + + data : tvm.Tensor + 3-D input shape [batch, in_channel, in_width] for layout == 'NCW' + and [batch, in_width, in_channel] for layout == 'NWC' + + kernel : tvm.Tensor + 3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW' + and [filter_size, in_channel, num_filter] for layout == 'NWC' + + strides : int or tuple + The spatial stride along width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + dilation : int or tuple + Dilation rate if convolution should be dilated. + + layout : str + How input data is laid out, must be one of ['NCW', 'NWC'] + + out_dtype : str + The output data type. If None then output is same type as input. + """ + if out_dtype is None: + out_dtype = data.dtype + if isinstance(strides, (tuple, list)): + strides = strides[0] + if isinstance(dilation, (tuple, list)): + dilation = dilation[0] + + if layout == 'NCW': + return nn.conv1d_ncw(data, kernel, strides, padding, dilation, + out_dtype) + if layout == 'NWC': + return nn.conv1d_nwc(data, kernel, strides, padding, dilation, + out_dtype) + raise ValueError("This layout is not yet supported: {}".format(layout)) + + +@autotvm.register_topi_schedule(generic.schedule_conv1d_ncw, ["cuda", "gpu"], + ["direct"]) +def schedule_conv1d_ncw(cfg, outs): + """TOPI schedule callback of conv1d ncw for cuda gpu + + Parameters + ---------- + cfg : ConfigEntity + the config for this template. + + outs : Array of Tensor + The computation graph description of conv1d + in the format of an array of tensors. 
+ + Returns + ------- + s : Schedule + The computation schedule for conv1d. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv1d_ncw': + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + + ##### space definition begin ##### + n, f, x = s[conv].op.axis + rc = s[conv].op.reduce_axis[0] + cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) + cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) + cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) + cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) + + target = tvm.target.current_target() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + ##### space definition end ##### + + if isinstance(kernel.op, + tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag: + s[kernel].compute_inline() + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, 'local') + else: + output = s.outputs[0].output(0) + s[conv].set_scope('local') + OL = conv + + # create cache stage + s[pad_data].set_scope('shared') + AA = pad_data + WW = s.cache_read(kernel, 'shared', [OL]) + + # tile and bind spatial axes + n, f, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + s[output].reorder(bn, bf, bx, vn, vf, vx, tn, tf, tx, ni, fi, xi) + s[output].bind(bn, tvm.thread_axis("blockIdx.z")) + s[output].bind(bf, tvm.thread_axis("blockIdx.y")) + s[output].bind(bx, tvm.thread_axis("blockIdx.x")) + s[output].bind(vn, tvm.thread_axis("vthread")) + s[output].bind(vf, tvm.thread_axis("vthread")) + s[output].bind(vx, tvm.thread_axis("vthread")) + + s[output].bind(tx, tvm.thread_axis("threadIdx.x")) + s[OL].compute_at(s[output], tx) + # number of threads + n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2] + n_tx = cfg["tile_x"].size[2] + + # tile reduction axes + n, f, x = s[OL].op.axis + rc, rx = s[OL].op.reduce_axis + rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, rcm, rx, rci, n, f, x) + + s[AA].compute_at(s[OL], rx) + s[WW].compute_at(s[OL], rx) + + # cooperative fetching + for load in [AA, WW]: + n, f, x = s[load].op.axis + fused = s[load].fuse(f, x) + tz, fused = s[load].split(fused, nparts=n_tz) + tx, fused = s[load].split(fused, nparts=n_tx) + s[load].bind(tz, tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + s[output].pragma(kernel_scope, 'auto_unroll_max_step', + cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', + cfg['unroll_explicit'].val) + + N, CO, OW = get_const_tuple(output.shape) + _, CI, KW = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OW * CO * KW * CI) + + traverse_inline(s, outs[0].op, _callback) + + return s + + +@autotvm.register_topi_schedule(generic.schedule_conv1d_nwc, ["cuda", "gpu"], + ["direct"]) +def schedule_conv1d_nwc(cfg, outs): + """TOPI schedule callback of conv1d nwc for cuda gpu + + Parameters + ---------- + cfg : ConfigEntity + the config for this template. + + outs : Array of Tensor + The computation graph description of conv1d + in the format of an array of tensors. + + Returns + ------- + s : Schedule + The computation schedule for conv1d. 
+ """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv1d_nwc': + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + + ##### space definition begin ##### + n, x, f = s[conv].op.axis + rc = s[conv].op.reduce_axis[0] + cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) + cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) + cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) + cfg.define_knob("auto_unroll_max_step", [64, 512, 1500]) + + target = tvm.target.current_target() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + ##### space definition end ##### + + if isinstance(kernel.op, + tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag: + s[kernel].compute_inline() + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, 'local') + else: + output = s.outputs[0].output(0) + s[conv].set_scope('local') + OL = conv + + # create cache stage + s[pad_data].set_scope('shared') + AA = pad_data + WW = s.cache_read(kernel, 'shared', [OL]) + + # tile and bind spatial axes + n, f, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + + s[output].reorder(bn, bx, bf, vn, vx, vf, tn, tx, tf, ni, xi, fi) + s[output].bind(bn, tvm.thread_axis("blockIdx.z")) + s[output].bind(bx, tvm.thread_axis("blockIdx.y")) + s[output].bind(bf, tvm.thread_axis("blockIdx.x")) + s[output].bind(vn, tvm.thread_axis("vthread")) + s[output].bind(vx, tvm.thread_axis("vthread")) + s[output].bind(vf, tvm.thread_axis("vthread")) + + s[output].bind(tf, tvm.thread_axis("threadIdx.x")) + s[OL].compute_at(s[output], tf) + # number of threads + n_tz = cfg["tile_n"].size[2] * cfg["tile_x"].size[2] + n_tx = cfg["tile_f"].size[2] + + # tile reduction axes + n, x, f = s[OL].op.axis + rc, rx = s[OL].op.reduce_axis + rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, rcm, rx, rci, n, x, f) + + s[AA].compute_at(s[OL], rx) + s[WW].compute_at(s[OL], rx) + + # cooperative fetching + for load in [AA, WW]: + n, x, f = s[load].op.axis + fused = s[load].fuse(x, f) + tz, fused = s[load].split(fused, nparts=n_tz) + tx, fused = s[load].split(fused, nparts=n_tx) + s[load].bind(tz, tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + s[output].pragma(kernel_scope, 'auto_unroll_max_step', + cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', + cfg['unroll_explicit'].val) + + N, OW, CO = get_const_tuple(output.shape) + KW, CI, _ = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OW * CO * KW * CI) + + traverse_inline(s, outs[0].op, _callback) + + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 980db65d9b8d..50154d3ee40b 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -34,6 +34,42 @@ def _default_schedule(outs, auto_inline): return s +@tvm.target.generic_func +def schedule_conv1d_ncw(outs): + """Schedule for conv1d_ncw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv1d_ncw + in the format of an array of tensors. 
+ + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + +@tvm.target.generic_func +def schedule_conv1d_nwc(outs): + """Schedule for conv1d_nwc + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv1d_nwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.generic_func def schedule_conv2d_hwcn(outs): """Schedule for conv2d_hwcn diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index b805b7c64919..4f0151b6a801 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -19,6 +19,7 @@ """Neural network operators""" from __future__ import absolute_import as _abs +from .conv1d import * from .conv2d import * from .conv3d import * from .deformable_conv2d import * diff --git a/topi/python/topi/nn/conv1d.py b/topi/python/topi/nn/conv1d.py new file mode 100644 index 000000000000..98fa2e3d7001 --- /dev/null +++ b/topi/python/topi/nn/conv1d.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument +"""1D convolution operators.""" +from __future__ import absolute_import as _abs +import tvm +from .pad import pad +from ..util import simplify +from .util import get_pad_tuple1d + + +@tvm.target.generic_func +def conv1d(data, + kernel, + strides=1, + padding='VALID', + dilation=1, + layout='NCW', + out_dtype=None): + """ 1D convolution forward operator. + + Parameters + ---------- + data : tvm.Tensor + 3-D input shape [batch, in_channel, in_width] for layout == 'NCW' + and [batch, in_width, in_channel] for layout == 'NWC' + + kernel : tvm.Tensor + 3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW' + and [filter_size, in_channel, num_filter] for layout == 'NWC' + + strides : int or tuple + The spatial stride along width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + dilation : int or tuple + Dilation rate if convolution should be dilated. + + layout : str + How input data is laid out, must be one of ['NCW', 'NWC'] + + out_dtype : str + The output data type. If None then output is same type as input. 
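+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, out_channel, out_width] for layout == 'NCW'
+        and [batch, out_width, out_channel] for layout == 'NWC'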
+ """ + if out_dtype is None: + out_dtype = data.dtype + if isinstance(strides, (tuple, list)): + strides = strides[0] + if isinstance(dilation, (tuple, list)): + dilation = dilation[0] + + if layout == 'NCW': + return conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype) + if layout == 'NWC': + return conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype) + raise ValueError("This layout is not yet supported: {}".format(layout)) + + +def conv1d_ncw(data, + kernel, + strides=1, + padding='VALID', + dilation=1, + out_dtype=None): + """ 1D convolution forward operator for NCW layout. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, in_channel, in_width] + + kernel : tvm.Tensor + 3-D with shape [num_filter, in_channel, filter_size] + + strides : int or tuple + The spatial stride along width + + padding : int, tuple, or str + Padding size can be an integer for equal padding, + a tuple of (left, right) or a string in ['VALID', 'SAME']. + + dilation : int or tuple + Dilation rate if convolution should be dilated. + + out_dtype : str + The output data type. If None then output is same type as input. + """ + batch, in_channels, data_width = data.shape + out_channels, _, kernel_size = kernel.shape + + # Compute the output shape + dilated_kernel_size = (kernel_size - 1) * dilation + 1 + pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, )) + out_channels = simplify(out_channels) + out_width = simplify( + (data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1) + + # Apply padding + pad_before = [0, 0, pad_left] + pad_after = [0, 0, pad_right] + temp = pad(data, pad_before, pad_after, name='pad_temp') + + # Compute graph + rc = tvm.reduce_axis((0, in_channels), name='rc') + rw = tvm.reduce_axis((0, kernel_size), name='rw') + + return tvm.compute( + (batch, out_channels, out_width), + lambda b, c, w: tvm.sum( + temp[b, rc, w * strides + rw * dilation].astype(out_dtype) + * kernel[c, rc, rw].astype(out_dtype), + axis=[rc, rw]), + tag="conv1d_ncw") + + +def conv1d_nwc(data, + kernel, + strides=1, + padding='VALID', + dilation=1, + out_dtype=None): + """ 1D convolution forward operator for NWC layout. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, in_width, in_channel] + + kernel : tvm.Tensor + 3-D with shape [filter_size, in_channel, num_filter] + + strides : int or tuple + The spatial stride along width + + padding : int, tuple, or str + Padding size can be an integer for equal padding, + a tuple of (left, right) or a string in ['VALID', 'SAME']. + + dilation : int or tuple + Dilation rate if convolution should be dilated. + + out_dtype : str + The output data type. If None then output is same type as input. 
+ """ + batch, data_width, in_channels = data.shape + kernel_size, _, out_channels = kernel.shape + + # Compute the output shape + dilated_kernel_size = (kernel_size - 1) * dilation + 1 + pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, )) + out_channels = simplify(out_channels) + out_width = simplify( + (data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1) + + # Apply padding + pad_before = [0, pad_left, 0] + pad_after = [0, pad_right, 0] + temp = pad(data, pad_before, pad_after, name='pad_temp') + + # Compute graph + rc = tvm.reduce_axis((0, in_channels), name='rc') + rw = tvm.reduce_axis((0, kernel_size), name='rw') + + return tvm.compute( + (batch, out_width, out_channels), + lambda b, w, c: tvm.sum( + temp[b, w * strides + rw * dilation, rc].astype(out_dtype) + * kernel[rw, rc, c].astype(out_dtype), + axis=[rc, rw]), + tag="conv1d_nwc") diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 87e48ff00600..91b7dc5bc60c 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -21,6 +21,7 @@ """ from __future__ import absolute_import as _abs +from .conv1d_ncw_python import conv1d_ncw_python from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python diff --git a/topi/python/topi/testing/conv1d_ncw_python.py b/topi/python/topi/testing/conv1d_ncw_python.py new file mode 100644 index 000000000000..90ee7de66808 --- /dev/null +++ b/topi/python/topi/testing/conv1d_ncw_python.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-variable, invalid-name +"""1D convolution in python""" +import numpy as np +from topi.nn.util import get_pad_tuple1d + + +def dilate_np(x, dilation): + """ 1D dilation using numpy + + Parameters + ---------- + x : numpy.ndarray + Array to dilate with shape [batch, in_channel, in_width] + + dilation : int + dilation rate of output + + Returns + ------- + out : numpy.ndarray + Dilated output with shape [batch, in_channel, (in_width - 1) * dilation + 1] + """ + irange = range(len(x) - 1) + for d in range(dilation - 1): + indices = [(d + 1)*(i + 1) for i in irange] + x = np.insert(x, indices, 0) + return x + + +def conv1d_ncw_python(a_np, w_np, stride, padding, dilation): + """1D convolution operator in NCW layout + + Parameters + ---------- + a_np : numpy.ndarray + 3-D with shape [batch, in_channel, in_width] + + w_np : numpy.ndarray + 3-D with shape [num_filter, in_channel, filter_width] + + stride : int + Stride size + + padding : int, tuple, or str + Single int for padding size or tuple of (left, right) padding + or a string in ['VALID', 'SAME'] + + dilation : int + Dilation rate of the kernel + + Returns + ------- + b_np : numpy.ndarray + 3-D with shape [batch, out_channel, out_width] + """ + batch, in_c, in_w = a_np.shape + out_c, _, filter_w = w_np.shape + if isinstance(stride, (tuple, list)): + stride = stride[0] + if isinstance(dilation, (tuple, list)): + dilation = dilation[0] + + dilated_filter_w = (filter_w - 1) * dilation + 1 + pad_left, pad_right = get_pad_tuple1d(padding, (dilated_filter_w,)) + out_w = ((in_w - dilated_filter_w + pad_left + pad_right) // stride) + 1 + + padded_a_np = np.zeros((batch, in_c, in_w + pad_left + pad_right)) + padded_a_np[:, :, pad_left:(in_w + pad_left)] = a_np + + b_np = np.zeros((batch, out_c, out_w)) + for n in range(batch): + for f in range(out_c): + for c in range(in_c): + out = np.convolve( + padded_a_np[n, c], np.flip(dilate_np(w_np[f, c], dilation)), mode='valid') + b_np[n, f] += out[::stride] + return b_np diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 6e41a1709f9c..af7f97415242 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -19,6 +19,7 @@ """x86 specific declaration and schedules.""" from __future__ import absolute_import as _abs +from .conv1d import schedule_conv1d_nwc from .conv2d import schedule_conv2d, schedule_conv2d_nhwc from .binarize_pack import schedule_binarize_pack from .binary_dense import schedule_binary_dense diff --git a/topi/python/topi/x86/conv1d.py b/topi/python/topi/x86/conv1d.py new file mode 100644 index 000000000000..95fd159acd47 --- /dev/null +++ b/topi/python/topi/x86/conv1d.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name +"""Conv1D schedule on for Intel CPU""" +from __future__ import absolute_import as _abs +import tvm +from .. import generic, tag + + +@generic.schedule_conv1d_ncw.register(["cpu"]) +def schedule_conv1d_ncw(outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 3: # schedule bias + bn + relu + n, c, w = op.axis + fused = s[op].fuse(n, c) + s[op].parallel(fused) + s[op].vectorize(w) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv1d_ncw' in op.tag: + conv = op.output(0) + kernel = op.input_tensors[1] + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + data = op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + n_pad, c_pad, w_pad = data_pad.op.axis + pad_fused = s[data_pad].fuse(n_pad, c_pad) + s[data_pad].parallel(pad_fused) + C = conv + n, c, w = C.op.axis + rc, rw = C.op.reduce_axis + n_out, c_out, w_out = output_op.axis + s[C].vectorize(w) + if op != output_op: # fuse bias + bn + relu into conv + s[C].compute_at(s[output_op], w_out) + else: + fused = s[C].fuse(n, c) + s[C].parallel(fused) + + scheduled_ops.append(op) + + traverse(output_op) + return s + + +@generic.schedule_conv1d_nwc.register(["cpu"]) +def schedule_conv1d_nwc(outs): + """Create schedule for tensors""" + s = tvm.create_schedule([x.op for x in outs]) + output_op = outs[0].op + scheduled_ops = [] + + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + else: # inject custom schedule + if len(op.axis) == 3: # schedule bias + bn + relu + n, w, c = op.axis + fused = s[op].fuse(n, w) + s[op].parallel(fused) + s[op].vectorize(c) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + traverse(tensor.op) + + if 'conv1d_nwc' in op.tag: + conv = op.output(0) + kernel = op.input_tensors[1] + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + data = op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + n_pad, w_pad, c_pad = data_pad.op.axis + pad_fused = s[data_pad].fuse(n_pad, w_pad) + s[data_pad].parallel(pad_fused) + C = conv + n, w, c = C.op.axis + rc, rw = C.op.reduce_axis + n_out, w_out, c_out = output_op.axis + s[C].vectorize(c) + if op != output_op: # fuse bias + bn + relu into conv + s[C].compute_at(s[output_op], c_out) + else: + fused = s[C].fuse(n, w) + s[C].parallel(fused) + + scheduled_ops.append(op) + + traverse(output_op) + return s diff --git a/topi/tests/python/test_topi_conv1d.py b/topi/tests/python/test_topi_conv1d.py new file mode 100644 index 000000000000..d54742c01d14 --- /dev/null +++ 
b/topi/tests/python/test_topi_conv1d.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for transposed convolution.""" +import numpy as np +import itertools +import tvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple +from common import get_all_backend + + +def verify_conv1d(batch, + in_channels, + in_width, + filters, + kernel_size=3, + stride=1, + dilation=1, + padding='VALID', + layout='NCW'): + if layout == 'NCW': + in_shape = [batch, in_channels, in_width] + kernel_shape = [filters, in_channels, kernel_size] + else: + in_shape = [batch, in_width, in_channels] + kernel_shape = [kernel_size, in_channels, filters] + + dtype = 'float32' + A = tvm.placeholder(in_shape, name='A', dtype=dtype) + W = tvm.placeholder(kernel_shape, name='W', dtype=dtype) + + def get_ref_data(layout): + a_np = np.random.uniform(size=in_shape).astype(dtype) + w_np = np.random.uniform(size=kernel_shape).astype(dtype) + if layout == 'NWC': + np_in = np.transpose(a_np, [0, 2, 1]) + np_w = np.transpose(w_np, [2, 1, 0]) + else: + np_in = a_np + np_w = w_np + b_np = topi.testing.conv1d_ncw_python(np_in, np_w, stride, padding, dilation) + if layout == 'NWC': + b_np = np.transpose(b_np, [0, 2, 1]) + return a_np, w_np, b_np + + a_np, w_np, b_np = get_ref_data(layout) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + with tvm.target.create(device): + B = topi.nn.conv1d(A, W, stride, padding, dilation, layout, 'float32') + if layout == 'NCW': + s = topi.generic.schedule_conv1d_ncw([B]) + else: + s = topi.generic.schedule_conv1d_nwc([B]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) + + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + + +def test_conv1d(): + for layout in ["NCW", "NWC"]: + # Most basic test case + verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'VALID', layout) + # With padding + verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'SAME', layout) + # Realistic dimensions + verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout) + # With stride + verify_conv1d(1, 16, 32, 16, 3, 2, 1, 'SAME', layout) + # With dilation + verify_conv1d(1, 16, 32, 16, 3, 1, 2, 'SAME', layout) + # Large batch size + verify_conv1d(8, 16, 32, 16, 3, 1, 1, 'SAME', layout) + # Other kernel sizes + verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout) + verify_conv1d(1, 16, 32, 16, 2, 1, 1, 'SAME', layout) + verify_conv1d(1, 16, 32, 16, 1, 1, 1, 'SAME', layout) + # Non-power-of-two shape + verify_conv1d(1, 17, 12, 21, 3, 1, 1, 
'SAME', layout) + verify_conv1d(1, 5, 27, 18, 3, 1, 1, 'VALID', layout) + + + +if __name__ == "__main__": + test_conv1d()
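A minimal usage sketch, not part of the patch itself: assuming an llvm-enabled build, the new relay.nn.conv1d can be checked against the NumPy reference added above, mirroring the Relay tests in this diff. The shapes are hypothetical and chosen only for illustration.

import numpy as np
import tvm
import topi
import topi.testing
from tvm import relay

# Hypothetical shapes: NCW data, OIW kernel.
dshape = (1, 3, 32)
kshape = (8, 3, 3)

x = relay.var("x", shape=dshape, dtype="float32")
w = relay.var("w", shape=kshape, dtype="float32")
y = relay.nn.conv1d(x, w, channels=8, kernel_size=3, padding=(1, 1))
func = relay.Function([x, w], y)

data = np.random.uniform(-1, 1, size=dshape).astype("float32")
kernel = np.random.uniform(-1, 1, size=kshape).astype("float32")

# Reference result from the NumPy implementation added in this patch.
ref = topi.testing.conv1d_ncw_python(data, kernel, 1, (1, 1), 1)

# Build and run on CPU, then compare against the reference.
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)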