Skip to content

Commit

Permalink
winograd_nnpack
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 committed Mar 26, 2019
1 parent 7f94247 commit 84a1394
Show file tree
Hide file tree
Showing 15 changed files with 828 additions and 31 deletions.
18 changes: 18 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,24 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
}
};

/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradNNPACKWeightTransformAttrs
: public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
int convolution_algorithm;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
"relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
TVM_ATTR_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
Expand Down
20 changes: 20 additions & 0 deletions nnvm/include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,26 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTrans
static const constexpr int kWeight = 0;
};

struct WinogradNNPACKWeightTransformParam
: public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
int convolution_algorithm;
int out_dtype;

DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
DMLC_DECLARE_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
.add_enum("same", -1)
.set_default(-1)
.describe("Output data type, set to explicit type under mixed precision setting");
}

static const constexpr int kWeight = 0;
};

struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
int channels;
TShape kernel_size;
Expand Down
47 changes: 47 additions & 0 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def alter_conv2d_layout(attrs, inputs, tinfos):
sym.contrib.conv2d_winograd_without_weight_transform
sym.contrib_conv2d_winograd_weight_transform = \
sym.contrib.conv2d_winograd_weight_transform
sym.contrib_conv2d_winograd_nnpack_without_weight_transform = \
sym.contrib.conv2d_winograd_nnpack_without_weight_transform
sym.contrib_conv2d_winograd_nnpack_weight_transform = \
sym.contrib.conv2d_winograd_nnpack_weight_transform
sym.nn = sym

# map relay argument names to nnvm argument names
Expand Down Expand Up @@ -274,6 +278,49 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, targe
OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
convolution_algorithm = attrs.get_int('convolution_algorithm')
out_dype = attrs.get_str('out_dtype')
return topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dype)


@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE)


@reg.register_compute("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_str("layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
inputs[0], inputs[1], inputs[2] if attrs.get_bool("use_bias") else None,
strides, padding, dilation, layout, out_dtype)
return out

@reg.register_schedule("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)


# conv2d_transpose
@reg.register_compute("conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, _):
Expand Down
81 changes: 79 additions & 2 deletions nnvm/src/top/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,14 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}

template<class Param>
inline bool WinogradConv2DInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");

const WinogradConv2DParam& param = nnvm::get<WinogradConv2DParam>(attrs.parsed);
const Param& param = nnvm::get<Param>(attrs.parsed);

const Layout in_layout(param.layout);
const Layout kernel_layout(param.kernel_layout);
Expand Down Expand Up @@ -403,7 +404,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<WinogradConv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<WinogradConv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
Expand All @@ -412,6 +413,82 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)

DMLC_REGISTER_PARAMETER(WinogradConv2DParam);


inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
const WinogradNNPACKWeightTransformParam& param =
nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);

CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
CHECK_EQ(out_type->size(), 1U);

if (param.out_dtype != -1) {
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
} else {
ElemwiseType<1, 1>(attrs, in_type, out_type);
}
return true;
}

NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "4D Tensor", "Weight tensor.")
.add_arguments(WinogradNNPACKWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradNNPACKWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradNNPACKWeightTransformParam>)
.set_attr<FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const TShape &wshape = (*in_shape)[0];
CHECK_EQ(wshape.ndim(), 4) << "Weight should be a 4 dimensional tensor";
TShape oshape({wshape[0], wshape[1], 8, 8});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
Layout layout("OIHW");
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
return true;
})
.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);

DMLC_REGISTER_PARAMETER(WinogradNNPACKWeightTransformParam);

NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_without_weight_transform)
.describe(R"code(Compute conv2d with winograd nnpack.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
We do not check shape for this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "4D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<Conv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(5);


NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/contrib/nnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ def convolution_inference_without_weight_transform(
ins[1],
ins[2] if bias is not None else 0,
outs[0], padding[0], padding[1], padding[2], padding[3],
stride[0], stride[1], nthreads, algorithm), name="C")
stride[0], stride[1], nthreads, algorithm), name="C", dtype='float32')

def convolution_inference_weight_transform(
kernel, nthreads=1,
algorithm=ConvolutionAlgorithm.AUTO):
algorithm=ConvolutionAlgorithm.AUTO,
dtype='float32'):
"""Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack.
Expand All @@ -171,13 +172,14 @@ def convolution_inference_weight_transform(
"""
assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
output_channels, input_channels, _, _ = kernel.shape

transform_tile_size = 8
if not isinstance(dtype, str):
dtype = dtype.dtype
return _api.extern(
(output_channels, input_channels, transform_tile_size, transform_tile_size),
[kernel],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)

_init_api("tvm.contrib.nnpack")
52 changes: 52 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,58 @@ def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)


# winograd nnpack related operators
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
attrs, inputs, out_dtype, target):
"""Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"

# No bias
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
out_dtype)

return [out]

@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
"""Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
convolution_algorithm = attrs.get_int('convolution_algorithm')
out = topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d NCHWc"""
Expand Down
Loading

0 comments on commit 84a1394

Please sign in to comment.