[FRONTEND][TF] Add conv3d (#4604)
* [FRONTEND][TF] Add conv3d

* fix high rtol
optima2005 authored and masahi committed Jan 1, 2020
1 parent f096c06 commit 1ef1605
Showing 17 changed files with 683 additions and 66 deletions.
6 changes: 5 additions & 1 deletion include/tvm/relay/attrs/nn.h
@@ -212,7 +212,11 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
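The attribute text above defines three accepted padding arities. A minimal Python sketch of that normalization (a hypothetical helper, not part of this commit), returning (front, top, left, back, bottom, right):

    # Hypothetical illustration of the conv3d padding arities described above.
    def normalize_conv3d_padding(padding):
        if len(padding) == 1:    # one int: same padding on all six sides
            return tuple(padding) * 6
        if len(padding) == 3:    # three ints: back/bottom/right mirror front/top/left
            return tuple(padding) * 2
        if len(padding) == 6:    # six ints: explicit (front, top, left, back, bottom, right)
            return tuple(padding)
        raise ValueError("padding must have 1, 3 or 6 elements")

    assert normalize_conv3d_padding([1]) == (1, 1, 1, 1, 1, 1)
    assert normalize_conv3d_padding([1, 2, 3]) == (1, 2, 3, 1, 2, 3)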
133 changes: 130 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
@@ -66,16 +66,18 @@ def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + surfix
+       if len(kernel) == 3:
+           return prefix + '3d' + surfix
        raise tvm.error.OpAttributeInvalid(
-           'Only 2D kernels are supported for operator {}'.format(prefix + '2d'))
+           'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
    return _impl

def _dimension_constraint():
    def _dim_check(attrs):
-       if len(attrs['kernel_shape']) == 2:
+       if len(attrs['kernel_shape']) in (2, 3):
            return True
        return False
-   return _dim_check, "Only 2d kernel supported."
+   return _dim_check, "Only 2d or 3d kernel supported."
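Taken together, the two helpers now admit rank-2 and rank-3 kernels. A standalone sketch of the resulting dispatch (the closure is inlined here; illustrative only):

    # Illustrative only: mirrors _dimension_picker's closure for 2D/3D kernels.
    def pick_dimension(prefix, kernel_shape, surfix=''):
        if len(kernel_shape) == 2:
            return prefix + '2d' + surfix
        if len(kernel_shape) == 3:
            return prefix + '3d' + surfix
        raise ValueError('only 2D or 3D kernels are supported')

    assert pick_dimension('conv', (3, 3)) == 'conv2d'
    assert pick_dimension('conv', (3, 3, 3), surfix='_transpose') == 'conv3d_transpose'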

def _get_param(params, input_node):
    if isinstance(input_node, _expr.Constant):
@@ -425,6 +427,130 @@ def _impl(inputs, attr, params):
        return out
    return _impl

def _conv3d(opname):
    def _impl(inputs, attr, params):
        attr['data_format'] = attr['data_format'].decode("utf-8")
        flip_layout = False

        inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]

        # NCDHW layout requires transposing the weights
        if attr['data_format'] == 'NCDHW':
            tmp_shape = attr['_input_shapes'][inputs[1]]
            tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
            inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
            attr['_input_shapes'][inputs[1]] = tmp_shape

        input_shape = attr['_input_shapes'][inputs_data]
        weights_shape = attr['_input_shapes'][inputs[1]]

        if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
            input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
            inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3))
            weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)]
            inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))

            attr['data_format'] = "NCDHW"
            attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)]
            flip_layout = True

        if attr['data_format'] == 'NDHWC':
            kernel_d, kernel_h, kernel_w, _, _ = weights_shape
            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
            if opname == 'conv':
                attr['channels'] = weights_shape[4]
            elif opname == 'conv_transpose':
                attr['channels'] = weights_shape[3]

            if 'dilations' in attr:
                attr['dilations'] = \
                    (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3])
            attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
        elif attr['data_format'] == 'NCDHW':
            _, _, kernel_d, kernel_h, kernel_w = weights_shape
            attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
            if opname == 'conv':
                attr['channels'] = weights_shape[0]
            elif opname == 'conv_transpose':
                attr['channels'] = weights_shape[1]

            if 'dilations' in attr:
                attr['dilations'] = \
                    (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4])
            attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
        else:
            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
                  'not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

        # Fix padding
        attr['padding'] = attr['padding'].decode("utf-8")

        if attr['padding'] == 'VALID':
            attr['padding'] = [0, 0, 0]
        elif attr['padding'] == 'SAME':
            stride_d, stride_h, stride_w = attr['strides']
            kernel_d, kernel_h, kernel_w = attr['kernel_shape']

            pdata_shape = input_shape
            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
                pdata_shape = attr['_output_shapes'][0]

            if attr['data_format'] == 'NDHWC':
                in_d = pdata_shape[1]
                in_h = pdata_shape[2]
                in_w = pdata_shape[3]
            else:
                in_d = pdata_shape[2]
                in_h = pdata_shape[3]
                in_w = pdata_shape[4]

            dilation_d = attr['dilations'][0]
            dilation_h = attr['dilations'][1]
            dilation_w = attr['dilations'][2]
            dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
            pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]

        else:
            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
                  'valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

        if 'kernel_layout' not in attr:
            attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'

        use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
        channel_axis = 1 if attr['data_format'] == "NCDHW" else 4

        # Ignore the new attributes from TF2.0, for now.
        out = AttrCvt(
            op_name=_dimension_picker('conv',
                                      surfix="_transpose" if opname == 'conv_transpose' else ""),
            ignores=['explicit_paddings'],
            transforms={
                'kernel_shape': 'kernel_size',
                'data_format': 'data_layout',
                'dilations': ('dilation', (0, 0)),
                'group': ('groups', 1)},
            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)

        if use_bias:
            out = _op.nn.bias_add(out,
                                  inputs[2] if opname != 'conv_transpose' else inputs[3],
                                  axis=channel_axis)

        if flip_layout:
            out = _op.transpose(out, axes=(0, 2, 3, 4, 1))

        return out
    return _impl
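The SAME branch above leans on `_get_pad_pair`, which is defined elsewhere in this file and not shown in the diff. A sketch of it under TensorFlow's usual SAME-padding rule, with a worked example for one axis (values are illustrative):

    # Assumed shape of _get_pad_pair (TF SAME rule); not part of this diff.
    def get_pad_pair(in_size, dilated_kernel, stride):
        if in_size % stride == 0:
            pad = max(dilated_kernel - stride, 0)
        else:
            pad = max(dilated_kernel - (in_size % stride), 0)
        pad_before = pad // 2
        return [pad_before, pad - pad_before]

    # Depth axis with in_d = 16, kernel_d = 3, dilation_d = 1, stride_d = 1:
    dilated_kernel_d = (3 - 1) * 1 + 1                        # = 3
    assert get_pad_pair(16, dilated_kernel_d, 1) == [1, 1]    # pad 1 front, 1 back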

def _decode_image():
    def _impl(inputs, attr, params):
        # Image decode wrapper: expecting the user to feed the decoded input
        # to the next layer, so this layer is dropped.
@@ -1442,6 +1568,7 @@ def _impl(inputs, attr, params):
    'Concat'                : _concat(),
    'ConcatV2'              : _concatV2(),
    'Conv2D'                : _conv('conv'),
+   'Conv3D'                : _conv3d('conv'),
    'Conv2DBackpropInput'   : _conv('conv_transpose'),
    'CropAndResize'         : _crop_and_resize(),
    'DecodeJpeg'            : _decode_image(),
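With 'Conv3D' mapped to `_conv3d('conv')`, a TensorFlow graph that uses it imports like any other op. A minimal usage sketch (TF 1.x API assumed; the shapes and the 'data' name are illustrative):

    import numpy as np
    import tensorflow as tf
    from tvm import relay

    data = tf.placeholder(tf.float32, [1, 8, 8, 8, 3], name='data')          # NDHWC
    kernel = tf.constant(np.random.rand(3, 3, 3, 3, 16).astype('float32'))   # DHWIO
    out = tf.nn.conv3d(data, kernel, strides=[1, 1, 1, 1, 1], padding='SAME')

    graph_def = tf.get_default_graph().as_graph_def()
    mod, params = relay.frontend.from_tensorflow(graph_def,
                                                 shape={'data': (1, 8, 8, 8, 3)})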
5 changes: 4 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -173,6 +173,7 @@ def _get_out_depth():
        assert len(weight_shape) == 5
        C, M, _, _, VC = weight_shape
        return C * VC * M

    if groups == 1:
        out = topi.nn.conv2d(
            inputs[0], inputs[1], strides, padding,
@@ -330,7 +331,7 @@ def compute_conv3d(attrs, inputs, out_type, target):
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

-   assert layout in ["NCDHW"]
+   assert layout in ["NCDHW", "NDHWC"]
    (dilation_d, dilation_h, dilation_w) = dilation
    if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be positive value")
@@ -353,6 +354,8 @@ def schedule_conv3d(attrs, outs, target):
    with target:
        if groups == 1 and layout == "NCDHW":
            return topi.generic.schedule_conv3d_ncdhw(outs)
+       elif groups == 1 and layout == "NDHWC":
+           return topi.generic.schedule_conv3d_ndhwc(outs)

    raise ValueError("No compatible schedule")

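With the NDHWC compute and schedule paths in place, a conv3d in that layout can be built directly through Relay. A minimal sketch (the module API varies slightly across versions; `tvm.IRModule` is assumed here):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 16, 16, 16, 3))     # NDHWC
    weight = relay.var("weight", shape=(3, 3, 3, 3, 8))    # DHWIO
    y = relay.nn.conv3d(data, weight, channels=8, kernel_size=(3, 3, 3),
                        padding=(1, 1, 1), data_layout="NDHWC",
                        kernel_layout="DHWIO")
    func = relay.Function([data, weight], y)
    # Building for llvm dispatches to topi.generic.schedule_conv3d_ndhwc above.
    mod = tvm.IRModule.from_expr(func)   # relay.Module.from_expr on older releases
    graph, lib, params = relay.build(mod, target="llvm")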
19 changes: 10 additions & 9 deletions src/relay/op/nn/convolution.cc
@@ -38,7 +38,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

template<typename T>
-Array<Array<Layout> > Conv2DInferCorrectLayout(
+Array<Array<Layout> > ConvInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
@@ -105,7 +105,7 @@ with the layer input to produce a tensor of outputs.
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
@@ -163,7 +163,8 @@ with the layer input to produce a tensor of outputs.
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>);
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);


// relay.nn.conv2d_transpose
@@ -337,7 +338,7 @@ (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
+                               ConvInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);


@@ -635,7 +636,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.set_support_level(10)
.add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
+                               ConvInferCorrectLayout<Conv2DWinogradAttrs>);

// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
@@ -744,7 +745,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
@@ -854,7 +855,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
.set_support_level(10)
.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DAttrs>);
+                               ConvInferCorrectLayout<Conv2DAttrs>);

// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
@@ -903,7 +904,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DAttrs>);
+                               ConvInferCorrectLayout<Conv2DAttrs>);


// Positional relay function to create depthwise conv2d NCHWc operator
@@ -953,7 +954,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               Conv2DInferCorrectLayout<Conv2DAttrs>);
+                               ConvInferCorrectLayout<Conv2DAttrs>);


bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
13 changes: 9 additions & 4 deletions src/relay/op/nn/convolution.h
@@ -28,6 +28,8 @@
#include <string>
#include <utility>

#include "../op_common.h"

namespace tvm {
namespace relay {

@@ -187,7 +189,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
param->kernel_size[1], param->kernel_size[2]}};
}

-    /*wshape = trans_kernel_layout.BackwardShape(wshape); */
+    wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
@@ -196,6 +198,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (weight != nullptr) {
weight_dtype = weight->dtype;
}

// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
@@ -225,22 +228,24 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});

+  IndexExpr pad_d, pad_h, pad_w;
+  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
  if (!dshape_ncdhw[2].as<ir::Any>()) {
-    oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z,
+    oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}

if (!dshape_ncdhw[3].as<ir::Any>()) {
-    oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y,
+    oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}

if (!dshape_ncdhw[4].as<ir::Any>()) {
-    oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x,
+    oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
param->strides[2]) + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
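In scalar terms, each spatial extent above follows out = (in + pad_total - dilated_ksize) // stride + 1, where pad_total is the summed before/after padding returned by GetPaddingDepthHeightWidth. A quick numeric check (Python, illustrative):

    def conv_out_size(in_size, kernel, stride, dilation, pad_total):
        dilated = (kernel - 1) * dilation + 1
        return (in_size + pad_total - dilated) // stride + 1

    assert conv_out_size(16, 3, 1, 1, 2) == 16   # SAME-style: size preserved
    assert conv_out_size(15, 3, 2, 1, 2) == 8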
39 changes: 39 additions & 0 deletions src/relay/op/op_common.h
@@ -162,6 +162,45 @@ inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
}
}

/*! \brief A utility function to get padding height and width from a tuple of 1, 2, or 4 ints. */
inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
                                  IndexExpr* pad_w) {
  if (padding.size() == 1) {
    *pad_h = padding[0] * 2;
    *pad_w = padding[0] * 2;
  } else if (padding.size() == 2) {
    *pad_h = padding[0] * 2;
    *pad_w = padding[1] * 2;
  } else if (padding.size() == 4) {
    *pad_h = padding[0] + padding[2];
    *pad_w = padding[1] + padding[3];
  } else {
    CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got "
                                << padding.size();
  }
}

/*! \brief A utility function to get padding depth, height and width from a tuple of 1, 3, or 6 ints. */
inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
                                       IndexExpr* pad_h, IndexExpr* pad_w) {
  if (padding.size() == 1) {
    *pad_d = padding[0] * 2;
    *pad_h = padding[0] * 2;
    *pad_w = padding[0] * 2;
  } else if (padding.size() == 3) {
    *pad_d = padding[0] * 2;
    *pad_h = padding[1] * 2;
    *pad_w = padding[2] * 2;
  } else if (padding.size() == 6) {
    *pad_d = padding[0] + padding[3];
    *pad_h = padding[1] + padding[4];
    *pad_w = padding[2] + padding[5];
  } else {
    CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got "
                                << padding.size();
  }
}

} // namespace relay
} // namespace tvm
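Note that both helpers yield the per-axis total padding (before + after), which is what the Conv3DRel arithmetic above consumes. A Python mirror of the 3D helper for a quick check (illustrative only):

    # Python mirror of GetPaddingDepthHeightWidth (illustrative only).
    def get_padding_depth_height_width(padding):
        if len(padding) == 1:
            return padding[0] * 2, padding[0] * 2, padding[0] * 2
        if len(padding) == 3:
            return padding[0] * 2, padding[1] * 2, padding[2] * 2
        if len(padding) == 6:
            return (padding[0] + padding[3], padding[1] + padding[4],
                    padding[2] + padding[5])
        raise ValueError("Padding size should be 1, 3 or 6")

    assert get_padding_depth_height_width([1, 2, 3, 4, 5, 6]) == (5, 7, 9)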
