
Commit

Fix winograd_nnpack_fp16
hlu1 committed Mar 11, 2019
1 parent 6d52eca commit ce8a38f
Showing 9 changed files with 77 additions and 43 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -160,13 +160,17 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
struct Conv2DWinogradNNPACKWeightTransformAttrs
    : public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
  int convolution_algorithm;
+  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
                    "relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
    TVM_ATTR_FIELD(convolution_algorithm)
        .describe(
            "The convolution algorithm for Winograd NNPACK. E.g. 3 for WT_8x8, "
            "6 for WT_8x8_FP16");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
  }
};

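With the new attribute, a Relay frontend can ask for the pre-transformed weights to be stored at a different precision than the kernel itself. A minimal sketch of the intended use (shape and algorithm id are illustrative; per the field description above, 6 selects WT_8x8_FP16):

    from tvm import relay

    weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")
    # Request float16 storage for the transformed weights; type inference
    # (see Conv2DWinogradNNPACKWeightTransformRel below) then yields
    # Tensor[(64, 64, 8, 8), float16].
    wt = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
        weight, convolution_algorithm=6, out_dtype="float16")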
5 changes: 5 additions & 0 deletions nnvm/include/nnvm/top/nn.h
@@ -186,12 +186,17 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
struct WinogradNNPACKWeightTransformParam
    : public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
  int convolution_algorithm;
+  int out_dtype;

  DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
    DMLC_DECLARE_FIELD(convolution_algorithm)
        .describe(
            "The convolution algorithm for Winograd NNPACK. E.g. 3 for WT_8x8, "
            "6 for WT_8x8_FP16");
+    DMLC_DECLARE_DTYPE_FIELD(out_dtype)
+        .add_enum("same", -1)
+        .set_default(-1)
+        .describe("Output data type, set to explicit type under mixed precision setting");
  }

  static const constexpr int kWeight = 0;
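On the NNVM side the attribute is declared as a dtype field backed by an int type code, with the extra enum value "same" (-1) standing for "keep the weight's dtype". A rough Python model of that convention (the helper and the concrete codes are illustrative, not part of NNVM):

    SAME = -1  # the "same" enum value declared above

    def resolve_out_dtype(in_dtype_code, out_dtype_code=SAME):
        # -1 defers to the input dtype; any explicit code wins.
        return in_dtype_code if out_dtype_code == SAME else out_dtype_code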
5 changes: 3 additions & 2 deletions nnvm/python/nnvm/top/nn.py
@@ -278,8 +278,9 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):

@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
return topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0],
attrs.get_int('convolution_algorithm'))
convolution_algorithm = attrs.get_int('convolution_algorithm')
out_dype = attrs.get_str('out_dtype')
return topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dype)

@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
21 changes: 19 additions & 2 deletions nnvm/src/top/nn/convolution.cc
@@ -414,6 +414,23 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
DMLC_REGISTER_PARAMETER(WinogradConv2DParam);

+inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
+                                            std::vector<int>* in_type,
+                                            std::vector<int>* out_type) {
+  const WinogradNNPACKWeightTransformParam& param =
+      nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);
+
+  CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
+  CHECK_EQ(out_type->size(), 1U);
+  printf("param.out_dtype: %d\n", param.out_dtype);
+  if (param.out_dtype != -1) {
+    NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
+  } else {
+    ElemwiseType<1, 1>(attrs, in_type, out_type);
+  }
+  return true;
+}
+
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
@@ -432,7 +449,7 @@ weight transformation in advance.
  TShape oshape({wshape[0], wshape[1], 8, 8});
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
@@ -442,7 +459,7 @@
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
return true;
})
-.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);
3 changes: 2 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -341,7 +341,8 @@ def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
out = topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], attrs.get_int('convolution_algorithm'))
convolution_algorithm = attrs.get_int('convolution_algorithm')
out = topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
6 changes: 4 additions & 2 deletions python/tvm/relay/op/nn/nn.py
@@ -993,7 +993,8 @@ def contrib_conv2d_winograd_weight_transform(weight,


def contrib_conv2d_winograd_nnpack_weight_transform(weight,
-                                                    convolution_algorithm):
+                                                    convolution_algorithm,
+                                                    out_dtype=""):
    r"""Weight Transformation part for 2D convolution with winograd algorithm.

    We separate this as a single op to enable pre-compute for inference.
@@ -1012,4 +1013,5 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
    result : tvm.relay.Expr
        The computed result.
    """
-    return _make.contrib_conv2d_winograd_nnpack_weight_transform(weight, convolution_algorithm)
+    return _make.contrib_conv2d_winograd_nnpack_weight_transform(
+        weight, convolution_algorithm, out_dtype)
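Since out_dtype defaults to the empty string, existing call sites are unaffected: an empty string reaches the C++ side as an unset DataType (the bits() == 0 case handled by Conv2DWinogradNNPACKWeightTransformRel below), so the transformed weights keep the weight's own dtype. A sketch of the unchanged call form (shape illustrative):

    from tvm import relay

    w = relay.var("w", shape=(32, 32, 3, 3), dtype="float32")
    # no out_dtype given: behaves exactly as before, output stays float32
    t = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
        w, convolution_algorithm=3)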
70 changes: 36 additions & 34 deletions src/relay/op/nn/convolution.cc
@@ -10,6 +10,7 @@
#include "../../pass/alter_op_layout.h"
#include "../layout.h"


namespace tvm {
namespace relay {

@@ -499,8 +500,8 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,

TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
});
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
});


RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
@@ -521,17 +522,17 @@ weight transformation in advance.
// Positional relay function to create conv2d winograd nnpack operator
// used by frontend FFI.
Expr MakeConv2DWinogradNNPACK(Expr data,
                              Expr weight,
                              Array<IndexExpr> strides,
                              Array<IndexExpr> padding,
                              Array<IndexExpr> dilation,
                              int groups,
                              IndexExpr channels,
                              Array<IndexExpr> kernel_size,
                              std::string data_layout,
                              std::string kernel_layout,
                              std::string out_layout,
                              DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
Expand All @@ -547,17 +548,15 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
});

runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
});

RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_nnpack_weight_transform.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_nnpack_weight_transform.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
Expand All @@ -572,60 +571,63 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);

// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);

bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
                                            int num_inputs,
                                            const Attrs& attrs,
                                            const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

-  const Conv2DWinogradNNPACKWeightTransformAttrs* param = attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
+  const Conv2DWinogradNNPACKWeightTransformAttrs* param =
+      attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
  CHECK(param != nullptr);

CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";

// each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
+  std::vector<IndexExpr> oshape{
      data->shape[0],
      data->shape[1],
      8,
      8,
  };

-  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
-                                                  data->dtype));
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
return true;
}

Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
-                                             int convolution_algorithm) {
+                                             int convolution_algorithm,
+                                             DataType out_dtype) {
  auto attrs = make_node<Conv2DWinogradNNPACKWeightTransformAttrs>();
  attrs->convolution_algorithm = convolution_algorithm;
+  attrs->out_dtype = std::move(out_dtype);
  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform");
  return CallNode::make(op, {weight}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
});

runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
});

RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
Separate this into another symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
.set_num_inputs(1)
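The net effect of the relation above, restated as a small Python mirror (a hypothetical stand-in, not a TVM API; the empty string plays the role of an unset DataType with bits() == 0):

    def weight_transform_rel(weight_shape, weight_dtype, out_dtype=""):
        # NCHW kernel (co, ci, kh, kw) -> transformed tiles (co, ci, 8, 8);
        # the dtype falls back to the weight's dtype when left unset.
        assert len(weight_shape) == 4, "Only support NCHW normal kernel layout"
        dtype = out_dtype if out_dtype else weight_dtype
        return (weight_shape[0], weight_shape[1], 8, 8), dtype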
4 changes: 3 additions & 1 deletion topi/python/topi/arm_cpu/conv2d.py
@@ -760,8 +760,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
    elif cfg.template_key == "winograd_nnpack_fp16" or cfg.template_key == "winograd_nnpack_fp32":
        # pre-compute winograd_nnpack transform
+        # for winograd_nnpack_fp16, the PrecomputePrune pass must run on device (where float16 is supported)
+        weight_dtype = 'same' if cfg.template_key == "winograd_nnpack_fp32" else 'float16'
        transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
-            copy_inputs[1], convolution_algorithm=cfg['winograd_nnpack_algorithm'].val)
+            copy_inputs[1], convolution_algorithm=cfg['winograd_nnpack_algorithm'].val, out_dtype=weight_dtype)
        copy_inputs[1] = transformed_kernel
        new_data = data
        new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")
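The dtype chosen here is the crux of the fix: 'float16' requests fp16 storage for the pre-transformed weights, while 'same' is understood by the NNVM parameter declared earlier (enum value -1) and keeps the kernel's own dtype. The selection logic, isolated into a sketch (the helper name is illustrative):

    def nnpack_weight_dtype(template_key):
        # fp32 template: keep the kernel dtype; fp16 template: store the
        # transformed weights as float16 on a device that supports it
        assert template_key in ("winograd_nnpack_fp16", "winograd_nnpack_fp32")
        return 'same' if template_key == "winograd_nnpack_fp32" else 'float16'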
2 changes: 1 addition & 1 deletion topi/python/topi/nn/conv2d.py
@@ -410,7 +410,7 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, dilation
    raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")


-def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm):
+def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
    """Weight transformation for winograd

    Parameters
    ----------
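At the TOPI level the new argument is threaded straight through to the target-specific implementation. A call sketch (assumes a TVM build with NNPACK enabled and a target, such as arm_cpu, that registers this op; shape and algorithm id are illustrative):

    import tvm
    import topi

    kernel = tvm.placeholder((64, 64, 3, 3), name="kernel", dtype="float32")
    # (co, ci, 3, 3) kernel -> (co, ci, 8, 8) Winograd-transformed weight tiles
    transformed = topi.nn.conv2d_winograd_nnpack_weight_transform(
        kernel, 3, "float32")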
