Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] enable blocking format in x86 conv2d and fold scale axis #5357

Merged
merged 13 commits into from
May 12, 2020
10 changes: 10 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import logging

import re
import topi
from tvm.te import SpecializedCondition
from .generic import *
from .. import op as _op

logger = logging.getLogger('strategy')

_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

@schedule_injective.register("cpu")
def schedule_injective_cpu(attrs, outs, target):
"""schedule injective ops for x86"""
Expand Down Expand Up @@ -96,6 +100,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
Expand Down Expand Up @@ -128,6 +135,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic")
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
Expand Down
2 changes: 2 additions & 0 deletions src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
namespace tvm {
namespace relay {

extern Expr MakeReshape(Expr data, Array<Integer> newshape);

template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down
151 changes: 118 additions & 33 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/transform.h>
#include <tvm/tir/data_layout.h>

#include "../op/tensor/transform.h"
#include "pass_util.h"
#include "pattern_util.h"

Expand All @@ -39,6 +40,7 @@ namespace relay {
*
* Use namespace to reduce potential naming conflict.
*/

namespace fold_scale_axis {

using runtime::TypedPackedFunc;
Expand Down Expand Up @@ -305,6 +307,41 @@ class ForwardPrep : private ExprVisitor {
}
};

// Membership test: returns true when `v` occurs as an element of `axis`.
static bool IsIntInArray(const Array<Integer>& axis, int v) {
  for (const auto& a : axis) {
    if (a == v) return true;
  }
  return false;
}

// Reshape `scale` so that it broadcasts against a tensor of the given
// `shape` along the dimensions listed in `axis` (every other dimension
// becomes 1). This is needed for blocked layouts such as NCHWc/OIHWio,
// where a logical channel is split across more than one axis and a
// simple expand_dims is not sufficient.
//
// Returns an undefined Expr() when any targeted dimension is not a
// compile-time constant; the caller treats this as "folding not
// possible" and falls back to the normal (non-folded) transform
// instead of raising a compilation error.
static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                               const Array<Integer>& axis) {
  Array<Integer> arr;
  for (size_t i = 0; i < shape.size(); i++) {
    if (IsIntInArray(axis, static_cast<int>(i))) {
      auto node = shape[i].as<IntImmNode>();
      if (!node) {
        // if the shape is not a constant, use normal transform
        return Expr();
      }
      arr.push_back(node->value);
    } else {
      arr.push_back(1);
    }
  }
  return MakeReshape(scale, std::move(arr));
}

// if only one axis, use expand dim. Else, use reshape
static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
                                       const Array<Integer>& axis) {
  return axis.size() > 1 ? ReshapeToMatchAxis(scale, shape, axis)
                         : ExpandBiasToMatchAxis(scale, shape.size(), axis);
}

//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
Expand Down Expand Up @@ -365,15 +402,21 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
if (slhs != nullptr) {
CHECK(srhs == nullptr);
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes);
Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
if (!scale.defined()) {
return Expr();
icemelon marked this conversation as resolved.
Show resolved Hide resolved
}
Expr rhs = Divide(new_args[1], scale);
rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
} else {
CHECK(srhs != nullptr);
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes);
Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
if (!scale.defined()) {
return Expr();
icemelon marked this conversation as resolved.
Show resolved Hide resolved
}
Expr lhs = Divide(new_args[0], scale);
rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = srhs->scale;
Expand Down Expand Up @@ -445,7 +488,6 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {

CHECK_GE(c_big_axis, 0);
Message none = NullValue<Message>();
AxesSet data_axes = NullValue<AxesSet>();
// For now, we only support simple pattern (no folded weight/data)
// More general layout can be supported under the current framework.
// By using a unified layout transformation.
Expand All @@ -454,12 +496,17 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
}
if (data_axes.defined()) {
return {Message(data_axes, false), none};
if (param->groups == 1 || is_depthwise_conv2d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return {Message(arr, false), none};
}
}
return {none, none};
}
Expand All @@ -478,12 +525,14 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value);
int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));

bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
CHECK(is_simple || is_blocking);

// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
Expand All @@ -493,11 +542,26 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,

// match the ic_axis
if (is_depthwise_conv2d) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale);
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis}));
if (!weight.defined()) return Expr();
}

} else {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis});
weight = Multiply(weight, scale);
if (is_simple) {
Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
weight = Multiply(weight, scale);
} else {
weight = Multiply(weight,
ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
{big_ki_axis, small_ki_axis}));
if (!weight.defined()) return Expr();
}
}
// return transformed conv2d
return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
Expand Down Expand Up @@ -752,14 +816,20 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes);
Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
if (!rhs_scale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
rhs = Multiply(rhs, rhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_message.defined()) {
CHECK(equal(message->axes, rhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr rhs = transformer->Transform(call->args[1], message, scale);
Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes);
Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
if (!lhs_scale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
lhs = Multiply(lhs, lhs_scale);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
Expand Down Expand Up @@ -829,13 +899,19 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return Message({c_big_axis}, false);
} else {
return NullValue<Message>();
if (param->groups == 1 || is_depthwise_conv2d) {
auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
(ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
Array<Integer> arr{c_big_axis};
if (c_small_axis >= 0) {
arr.push_back(c_small_axis);
}
return Message(arr, false);
}
}
return NullValue<Message>();
}

// Conv2D consumes the scale axis during transformation.
Expand All @@ -852,19 +928,28 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value);

int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
// Check it must be depthwise or full conv2d.
bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
CHECK(is_simple || is_blocking);

Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
// scale on input for depthwise.
Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis});
Expr wscale;
if (is_simple) {
wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
} else {
wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis});
if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
}
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
}
Expand Down
Loading