[TOP][ARM] Winograd nnpack backend #2721

Merged (1 commit, Mar 26, 2019)
18 changes: 18 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -155,6 +155,24 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
  }
};

/*! \brief Attributes used in Winograd weight transformation operators */
struct Conv2DWinogradNNPACKWeightTransformAttrs
    : public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
  int convolution_algorithm;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
                    "relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
    TVM_ATTR_FIELD(convolution_algorithm)
        .describe(
            "The convolution algorithm for Winograd NNPACK, "
            "e.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8 or "
            "tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16.");
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type; set to an explicit type under mixed-precision settings.");
  }
};

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
20 changes: 20 additions & 0 deletions nnvm/include/nnvm/top/nn.h
@@ -183,6 +183,26 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
  static const constexpr int kWeight = 0;
};

struct WinogradNNPACKWeightTransformParam
    : public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
  int convolution_algorithm;
  int out_dtype;

  DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
    DMLC_DECLARE_FIELD(convolution_algorithm)
        .describe(
            "The convolution algorithm for Winograd NNPACK, "
            "e.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8 or "
            "tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16.");
    DMLC_DECLARE_DTYPE_FIELD(out_dtype)
        .add_enum("same", -1)
        .set_default(-1)
        .describe("Output data type; set to an explicit type under mixed-precision settings.");
  }

  static const constexpr int kWeight = 0;
};

struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
  int channels;
  TShape kernel_size;
47 changes: 47 additions & 0 deletions nnvm/python/nnvm/top/nn.py
@@ -161,6 +161,10 @@ def alter_conv2d_layout(attrs, inputs, tinfos):
        sym.contrib.conv2d_winograd_without_weight_transform
    sym.contrib_conv2d_winograd_weight_transform = \
        sym.contrib.conv2d_winograd_weight_transform
    sym.contrib_conv2d_winograd_nnpack_without_weight_transform = \
        sym.contrib.conv2d_winograd_nnpack_without_weight_transform
    sym.contrib_conv2d_winograd_nnpack_weight_transform = \
        sym.contrib.conv2d_winograd_nnpack_weight_transform
    sym.nn = sym

    # map relay argument names to nnvm argument names
@@ -274,6 +278,49 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
convolution_algorithm = attrs.get_int('convolution_algorithm')
out_dype = attrs.get_str('out_dtype')
return topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dype)


@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE)


@reg.register_compute("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_str("layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
inputs[0], inputs[1], inputs[2] if attrs.get_bool("use_bias") else None,
strides, padding, dilation, layout, out_dtype)
return out

@reg.register_schedule("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)
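
For context, a hedged sketch of how the two symbols registered above compose at graph level: the weight transform runs once (and can be folded ahead of time by the precompute pass), and the convolution then consumes the pre-transformed tiles. Symbol and attribute names follow this diff; the shapes and use_bias=False are illustrative assumptions.

# Hedged sketch: assumes an nnvm build that includes this PR, with NNPACK enabled.
import nnvm.symbol as sym
from tvm.contrib.nnpack import ConvolutionAlgorithm

data = sym.Variable("data")      # e.g. (1, 32, 56, 56), NCHW
weight = sym.Variable("weight")  # e.g. (64, 32, 3, 3), OIHW

# Precomputable step: maps an (O, I, 3, 3) kernel to (O, I, 8, 8) Winograd tiles.
tweight = sym.contrib.conv2d_winograd_nnpack_weight_transform(
    weight, convolution_algorithm=ConvolutionAlgorithm.WT_8x8)

# Consumes the pre-transformed weight; dilation must be (1, 1) and groups == 1.
out = sym.contrib.conv2d_winograd_nnpack_without_weight_transform(
    data, tweight, channels=64, kernel_size=(3, 3),
    strides=(1, 1), padding=(1, 1), use_bias=False)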


# conv2d_transpose
@reg.register_compute("conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, _):
81 changes: 79 additions & 2 deletions nnvm/src/top/nn/convolution.cc
@@ -130,13 +130,14 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
  return true;
}

template<class Param>
inline bool WinogradConv2DInferShape(const nnvm::NodeAttrs& attrs,
                                     std::vector<TShape>* in_shape,
                                     std::vector<TShape>* out_shape) {
  static const Layout kNCHW("NCHW");
  static const Layout kOIHW("OIHW");

-  const WinogradConv2DParam& param = nnvm::get<WinogradConv2DParam>(attrs.parsed);
+  const Param& param = nnvm::get<Param>(attrs.parsed);

  const Layout in_layout(param.layout);
  const Layout kernel_layout(param.kernel_layout);
@@ -403,7 +404,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<WinogradConv2DParam>)
-.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
+.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<WinogradConv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
@@ -412,6 +413,82 @@

DMLC_REGISTER_PARAMETER(WinogradConv2DParam);


inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
                                            std::vector<int>* in_type,
                                            std::vector<int>* out_type) {
  const WinogradNNPACKWeightTransformParam& param =
      nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);

  CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
  CHECK_EQ(out_type->size(), 1U);

  if (param.out_dtype != -1) {
    NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
  } else {
    ElemwiseType<1, 1>(attrs, in_type, out_type);
  }
  return true;
}

NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
.describe(R"code(Weight transformation for the Winograd fast convolution algorithm.
This is kept as a separate nnvm symbol so that the Precompute pass can fold the
weight transformation ahead of time.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "4D Tensor", "Weight tensor.")
.add_arguments(WinogradNNPACKWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradNNPACKWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradNNPACKWeightTransformParam>)
.set_attr<FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
                                         std::vector<TShape> *in_shape,
                                         std::vector<TShape> *out_shape) {
  const TShape &wshape = (*in_shape)[0];
  CHECK_EQ(wshape.ndim(), 4) << "Weight should be a 4-dimensional tensor";
  TShape oshape({wshape[0], wshape[1], 8, 8});
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
                                               std::vector<Layout> *ilayouts,
                                               const std::vector<Layout> *last_ilayouts,
                                               std::vector<Layout> *olayouts) {
  Layout layout("OIHW");
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, layout);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
  return true;
})
.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);

DMLC_REGISTER_PARAMETER(WinogradNNPACKWeightTransformParam);

NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_without_weight_transform)
.describe(R"code(Compute conv2d with Winograd NNPACK.
- **data**: Input is a 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape; the shape of this input tensor is not checked.
- **bias**: (channels,)
- **out**: Output is a 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "4D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<Conv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(5);


NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.

10 changes: 6 additions & 4 deletions python/tvm/contrib/nnpack.py
@@ -149,11 +149,12 @@ def convolution_inference_without_weight_transform(
            ins[1],
            ins[2] if bias is not None else 0,
            outs[0], padding[0], padding[1], padding[2], padding[3],
-           stride[0], stride[1], nthreads, algorithm), name="C")
+           stride[0], stride[1], nthreads, algorithm), name="C", dtype='float32')

def convolution_inference_weight_transform(
        kernel, nthreads=1,
-       algorithm=ConvolutionAlgorithm.AUTO):
+       algorithm=ConvolutionAlgorithm.AUTO,
+       dtype='float32'):
"""Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.

@@ -171,13 +172,14 @@ def convolution_inference_weight_transform(
"""
    assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
    output_channels, input_channels, _, _ = kernel.shape
-
    transform_tile_size = 8
+   if not isinstance(dtype, str):
+       dtype = dtype.dtype
    return _api.extern(
        (output_channels, input_channels, transform_tile_size, transform_tile_size),
        [kernel],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.nnpack.convolution_inference_weight_transform",
-           ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
+           ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)

_init_api("tvm.contrib.nnpack")
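
A minimal usage sketch of the transform path with the new dtype argument (shapes are assumed; compiling and running requires a TVM build with NNPACK): a (64, 32, 3, 3) kernel becomes a (64, 32, 8, 8) tensor of transformed tiles, stored as float16 when WT_8x8_FP16 is selected.

import tvm
from tvm.contrib import nnpack

kernel = tvm.placeholder((64, 32, 3, 3), name="kernel", dtype="float32")
# Pre-transform the kernel into 8x8 Winograd tiles, keeping them in fp16.
tkernel = nnpack.convolution_inference_weight_transform(
    kernel, algorithm=nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
    dtype="float16")
print(tkernel.shape, tkernel.dtype)  # expect (64, 32, 8, 8) and float16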
52 changes: 52 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -326,6 +326,58 @@ def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)


# winograd nnpack related operators
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
        attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    data_layout = attrs.get_str("data_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    assert dilation == (1, 1), "Dilation is not supported yet"
    assert groups == 1, "Arbitrary group numbers are not supported yet"

    # No bias
    out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
        inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
        out_dtype)

    return [out]

@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                     OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
convolution_algorithm = attrs.get_int('convolution_algorithm')
out = topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
OpPattern.OPAQUE)
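
For completeness, a hedged sketch of the Relay-level pair registered above, written out by hand; in practice the rewrite from nn.conv2d into these ops is produced by the ARM alter_op_layout pass. Shapes are illustrative, and the wrapper names are assumed to match the Python API added alongside these registrations.

from tvm import relay
from tvm.contrib.nnpack import ConvolutionAlgorithm

data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 32, 3, 3), dtype="float32")

# Fold-able weight transform: (64, 32, 3, 3) -> (64, 32, 8, 8) tiles.
tweight = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
    weight, convolution_algorithm=ConvolutionAlgorithm.WT_8x8,
    out_dtype="float32")
conv = relay.nn.contrib_conv2d_winograd_nnpack_without_weight_transform(
    data, tweight, strides=(1, 1), padding=(1, 1),
    channels=64, kernel_size=(3, 3), out_dtype="float32")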


@reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d NCHWc"""