
[RFC] Improve quantized convolution performance for armv8 architectures #5754

Merged (15 commits) on Jun 23, 2020
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -187,6 +187,17 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig
}
};

/*! \brief Attributes used in gemm weight transformation operators */
struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
int tile_rows;
int tile_cols;

TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
}
};

/*! \brief Attributes used in convolution operators with winograd algorithm */
struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
int tile_size;
17 changes: 17 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -446,6 +446,23 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d_gemm related operators
reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform",
strategy.conv2d_gemm_without_weight_transform_strategy)
reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of contrib_conv2d_gemm_weight_transform"""
out = topi.nn.conv2d_gemm_weight_transform(
inputs[0], attrs.tile_rows, attrs.tile_cols)
return [out]

reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform",
strategy.schedule_conv2d_gemm_weight_transform)
reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
91 changes: 91 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -2046,6 +2046,74 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype)


def contrib_conv2d_gemm_without_weight_transform(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with gemm algorithm.

The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight has already been pre-transformed by nn.contrib_conv2d_gemm_weight_transform.

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

weight : tvm.relay.Expr
The weight expressions.

strides : tuple of int, optional
The strides of convolution.

padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.

dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.

groups : int, optional
Number of groups for grouped convolution.

channels : int, optional
Number of output channels of this convolution.

kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.

data_layout : str, optional
Layout of the input.

kernel_layout : str, optional
Layout of the weight.

out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout

out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_gemm_without_weight_transform(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
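
For reference, a minimal Relay sketch (not part of this diff) of how the two contrib ops pair up. The shapes and tile sizes are illustrative, and in practice these ops are inserted by layout/alter-op passes rather than written by hand:

from tvm import relay

data = relay.var("data", shape=(1, 56, 56, 64), dtype="int8")     # NHWC input
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")  # HWIO kernel

# Transform the weights once so the result can be constant-folded at compile
# time, then feed the transformed weights to the GEMM-based convolution.
wt = relay.nn.contrib_conv2d_gemm_weight_transform(weight, tile_rows=4, tile_cols=16)
out = relay.nn.contrib_conv2d_gemm_without_weight_transform(
    data, wt, strides=(1, 1), padding=(1, 1), channels=64, kernel_size=(3, 3),
    data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")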


def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
@@ -2204,6 +2272,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)


def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
r"""Weight Transformation part for 2D convolution with gemm algorithm.

We separate this into its own op so the transformed weights can be precomputed for inference.
Use this together with nn.contrib_conv2d_gemm_without_weight_transform.

Parameters
----------
weights : tvm.relay.Expr
The weight expressions.
tile_rows : int
Tile rows of the weight transformation for ConvGemm.
tile_cols : int
Tile columns of the weight transformation for ConvGemm.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
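
As a rough illustration only (not the actual TOPI compute), a numpy sketch of the block-interleaved layout that tile_rows and tile_cols describe, assuming the HWIO kernel is flattened to a (K, N) matrix with K = KH*KW*IC and N = OC; the exact padding and indexing in topi.nn.conv2d_gemm_weight_transform may differ:

import numpy as np

def gemm_weight_transform_sketch(kernel_hwio, tile_rows, tile_cols):
    kh, kw, ic, oc = kernel_hwio.shape
    flat = kernel_hwio.reshape(kh * kw * ic, oc)   # (K, N) matrix
    k, n = flat.shape
    pad_k = -k % tile_cols                         # pad K up to a multiple of tile_cols
    pad_n = -n % tile_rows                         # pad N up to a multiple of tile_rows
    padded = np.pad(flat, ((0, pad_k), (0, pad_n)))
    k_full, n_full = padded.shape
    # Interleave into (N/tile_rows, K/tile_cols, tile_rows, tile_cols) blocks so a
    # GEMM micro-kernel can read each tile contiguously.
    out = np.empty((n_full // tile_rows, k_full // tile_cols, tile_rows, tile_cols),
                   dtype=padded.dtype)
    for x in range(n_full // tile_rows):
        for y in range(k_full // tile_cols):
            out[x, y] = padded[y * tile_cols:(y + 1) * tile_cols,
                               x * tile_rows:(x + 1) * tile_rows].T
    return out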


def contrib_conv3d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 3D convolution with winograd algorithm.
42 changes: 42 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -112,6 +112,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd),
name='conv2d_direct_simd.micro_dev')
elif kernel_layout == "HWIO":
is_aarch64 = "aarch64" in str(isa.target)

if is_aarch64 and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized.arm_cpu")

strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
@@ -246,6 +254,40 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
format(layout))
return strategy

def wrap_compute_conv2d_gemm(topi_compute):
"""wrap topi compute for conv2d_gemm"""

def _compute_conv2d_gemm(attrs, inputs, out_type):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_dtype = attrs.get_str("out_dtype")
channels = attrs['channels']
kernel_size = attrs['kernel_size']
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
return [topi_compute(inputs[0], inputs[1], strides, padding,
dilation, out_dtype, kernel_size, channels)]

return _compute_conv2d_gemm

@conv2d_gemm_without_weight_transform_strategy.register("arm_cpu")
def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom arm cpu strategy"""
layout = attrs.data_layout
data = inputs[0]
strategy = _op.OpStrategy()

if layout == "NHWC" and data.dtype in ['int8', 'uint8']:
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized_without_transform.arm_cpu")
else:
raise RuntimeError(
"Unsupported conv2d_gemm_without_weight_transform layout {0} with datatype {1}".
format(layout, data.dtype))
return strategy
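
A hedged end-to-end sketch (not from this diff) of compiling an int8 NHWC/HWIO convolution for an AArch64 target, so the conv2d_NHWC_quantized strategy registered above becomes eligible. The target string and shapes are illustrative:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 56, 56, 64), dtype="int8")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=64,
                       padding=(1, 1), data_layout="NHWC",
                       kernel_layout="HWIO", out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# An arm_cpu target whose triple contains "aarch64" enables the new schedule.
target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)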

@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_transpose arm cpu strategy"""
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -266,6 +266,12 @@ def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, t
"""conv2d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")

# conv2d_gemm_without_weight_transform
@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
"""conv2d_gemm_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")

# conv2d_winograd_weight_transform
@generic_func
def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
@@ -280,6 +286,13 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

# conv2d_gemm_weight_transform
@generic_func
def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
"""Schedule conv2d_gemm_weight_transform"""
with target:
return topi.generic.schedule_conv2d_gemm_weight_transform(outs)

# deformable_conv2d
def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""
8 changes: 7 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -237,17 +237,23 @@ def is_fast_int8_on_arm():
target = tvm.target.Target.current(allow_none=False)
return '+v8.2a,+dotprod' in ' '.join(target.options)

def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return 'aarch64' in ' '.join(target.options)

########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be same.
if is_fast_int8_on_arm():
if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
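
The dtype unification matters because the new AArch64 NHWC path, like the fast-int8 path, wants data and kernel in the same dtype. A small numpy sketch (an illustration of the idea, not the helper's code) of why the conversion is lossless: a uint8 tensor with zero point zp represents the same real values as an int8 tensor with zero point zp - 128:

import numpy as np

data_u8 = np.array([0, 37, 128, 255], dtype=np.uint8)
zp_u8 = 128                                            # uint8 zero point
data_i8 = (data_u8.astype(np.int16) - 128).astype(np.int8)
zp_i8 = zp_u8 - 128                                    # shifted zero point
# (q - zero_point) is unchanged, so the dequantized values are identical.
assert np.array_equal(data_u8.astype(np.int32) - zp_u8,
                      data_i8.astype(np.int32) - zp_i8)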


@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be the same.
82 changes: 82 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -77,13 +77,41 @@ Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> st
return Call(op, {data, weight}, Attrs(attrs), {});
}

template <typename T>
Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
Array<IndexExpr> kernel_size, std::string data_layout, std::string kernel_layout,
std::string out_layout, DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
const Op& op = Op::Get(op_name);
return Call(op, {data, weight}, Attrs(attrs), {});
}

Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
const Op& op = Op::Get(op_name);
return Call(op, {weight}, Attrs(attrs), {});
}

Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
auto attrs = make_object<ConvGemmWeightTransformAttrs>();
attrs->tile_rows = tile_rows;
attrs->tile_cols = tile_cols;
const Op& op = Op::Get(op_name);
return Call(op, {weight}, Attrs(attrs), {});
}

template <typename T>
Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
@@ -504,6 +532,60 @@ weight transformation in advance.
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);

// relay.nn.contrib_conv2d_gemm_without_weight_transform
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform")
.set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
Array<IndexExpr> kernel_size, std::string data_layout,
std::string kernel_layout, std::string out_layout, DataType out_dtype) {
return MakeConvGemm<Conv2DAttrs>(
data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform");
});

RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
.describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_gemm_weight_transform.

- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
- **weight**: Any shape
We do not check the shape of this input tensor, since different backends
use different layout strategies.

- **out**: Output is 4D array of shape (batch_size, out_height, out_width, channels)
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DGemm", Conv2DGemmRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);

// relay.nn.contrib_conv2d_gemm_weight_transform

TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);

TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform")
.set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
"nn.contrib_conv2d_gemm_weight_transform");
});

RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform")
.describe(R"code(Weight transformation of GEMM convolution algorithm.

Separate this into another operator so that the precompute pass can compute the
weight transformation in advance.

)code" TVM_ADD_FILELINE)
.set_attrs_type<ConvGemmWeightTransformAttrs>()
.set_num_inputs(1)
.add_argument("weights", "Tensor", "The weights tensor.")
.set_support_level(10)
.add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel);

// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")