clang format
Menooker committed May 12, 2020
1 parent c8efc1b commit 5192ad7
Showing 1 changed file with 48 additions and 60 deletions.
src/relay/transforms/fold_scale_axis.cc (108 changes: 48 additions & 60 deletions)
@@ -29,8 +29,8 @@
 #include <tvm/relay/transform.h>
 #include <tvm/tir/data_layout.h>
 #include "../op/tensor/transform.h"
-#include "pattern_util.h"
 #include "pass_util.h"
+#include "pattern_util.h"

 #include "pass_util.h"
 #include "pattern_util.h"
@@ -311,8 +311,7 @@ class ForwardPrep : private ExprVisitor {

 static bool IsIntInArray(const Array<Integer>& axis, int v) {
 for (size_t i = 0; i < axis.size(); i++) {
-if (axis[i] == v)
-return true;
+if (axis[i] == v) return true;
 }
 return false;
 }
@@ -370,8 +369,8 @@ Expr ReluForwardRewrite(const Call& ref_call, const Array<Expr>& new_args, const

 RELAY_REGISTER_OP("nn.relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

-RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
-ReluForwardRewrite);
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", ReluForwardRewrite);

 RELAY_REGISTER_OP("nn.leaky_relu").set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);

@@ -405,8 +404,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
 if (slhs != nullptr) {
 CHECK(srhs == nullptr);
 CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
-Expr scale = ReshapeOrExpandToMatchAxis(
-slhs->scale, tlhs->shape, slhs->axes);
+Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
 if (!scale.defined()) {
 return Expr();
 }
@@ -417,8 +415,7 @@
 } else {
 CHECK(srhs != nullptr);
 CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
-Expr scale = ReshapeOrExpandToMatchAxis(
-srhs->scale, trhs->shape, srhs->axes);
+Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
 if (!scale.defined()) {
 return Expr();
 }
@@ -432,8 +429,8 @@

 RELAY_REGISTER_OP("add").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

-RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FScaleAxisForwardRewrite",
-AddSubForwardRewrite);
+RELAY_REGISTER_OP("add")
+.set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", AddSubForwardRewrite);

 RELAY_REGISTER_OP("subtract").set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);

@@ -504,14 +501,14 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
 if (param->groups == 1 || is_depthwise_conv2d) {
 auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
 auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
-if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
+if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
 (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
-Array<Integer> arr{c_big_axis};
-if (c_small_axis >= 0) {
-arr.push_back(c_small_axis);
-}
-return {Message(arr, false), none};
+Array<Integer> arr{c_big_axis};
+if (c_small_axis >= 0) {
+arr.push_back(c_small_axis);
+}
+return {Message(arr, false), none};
 }
 }
 return {none, none};
 }
@@ -548,28 +545,24 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
 // match the ic_axis
 if (is_depthwise_conv2d) {
 if (is_simple) {
-Expr scale = ExpandBiasToMatchAxis(
-sdata->scale, kernel_layout.ndim(), {big_ko_axis});
+Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
 weight = Multiply(weight, scale);
 } else {
-weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale,
-weight->type_as<TensorTypeNode>()->shape,
-{big_ko_axis, small_ko_axis}));
-if (!weight.defined())
-return Expr();
+weight = Multiply(weight,
+ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
+{big_ko_axis, small_ko_axis}));
+if (!weight.defined()) return Expr();
 }

 } else {
 if (is_simple) {
-Expr scale = ExpandBiasToMatchAxis(
-sdata->scale, kernel_layout.ndim(), {big_ki_axis});
+Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
 weight = Multiply(weight, scale);
 } else {
-weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale,
-weight->type_as<TensorTypeNode>()->shape,
-{big_ki_axis, small_ki_axis}));
-if (!weight.defined())
-return Expr();
+weight = Multiply(weight,
+ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
+{big_ki_axis, small_ki_axis}));
+if (!weight.defined()) return Expr();
 }
 }
 // return transformed conv2d
@@ -775,8 +768,8 @@ Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr&

 RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);

-RELAY_REGISTER_OP("nn.relu").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
-ReluBackwardTransform);
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);

 RELAY_REGISTER_OP("nn.leaky_relu")
 .set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
@@ -824,10 +817,8 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
 } else if (lhs_message.defined()) {
 CHECK(equal(message->axes, lhs_message->axes));
 Expr lhs = transformer->Transform(call->args[0], message, scale);
-Expr rhs = transformer->Transform(
-call->args[1], NullValue<Message>(), NullValue<Expr>());
-Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape,
-message->axes);
+Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
+Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
 if (!rhs_scale.defined()) {
 return transformer->NormalCallTransform(call.operator->());
 }
@@ -837,8 +828,7 @@
 CHECK(equal(message->axes, rhs_message->axes));
 Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
 Expr rhs = transformer->Transform(call->args[1], message, scale);
-Expr lhs_scale = ReshapeOrExpandToMatchAxis(
-scale, trhs->shape, message->axes);
+Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
 if (!lhs_scale.defined()) {
 return transformer->NormalCallTransform(call.operator->());
 }
@@ -852,8 +842,8 @@

 RELAY_REGISTER_OP("add").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

-RELAY_REGISTER_OP("add").set_attr<FBackwardTransform>("FScaleAxisBackwardTransform",
-AddSubBackwardTransform);
+RELAY_REGISTER_OP("add")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);

 RELAY_REGISTER_OP("subtract").set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);

@@ -914,14 +904,14 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
 if (param->groups == 1 || is_depthwise_conv2d) {
 auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
 auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
-if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
+if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout
 (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout
-Array<Integer> arr{c_big_axis};
-if (c_small_axis >= 0) {
-arr.push_back(c_small_axis);
-}
-return Message(arr, false);
+Array<Integer> arr{c_big_axis};
+if (c_small_axis >= 0) {
+arr.push_back(c_small_axis);
+}
+return Message(arr, false);
 }
 }
 return NullValue<Message>();
 }
@@ -956,13 +946,11 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
 // scale on input for deptwise.
 Expr wscale;
 if (is_simple) {
-wscale = ExpandBiasToMatchAxis(
-scale, kernel_layout.ndim(), {big_ko_axis});
+wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
 } else {
 wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
-{big_ko_axis, small_ko_axis});
-if (!wscale.defined())
-return transformer->NormalCallTransform(call.operator->());
+{big_ko_axis, small_ko_axis});
+if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
 }
 weight = Multiply(weight, wscale);
 return Call(call->op, {data, weight}, call->attrs, call->type_args);
@@ -983,20 +971,20 @@ Expr BackwardFoldScaleAxis(const Expr& data) {
 namespace transform {

 Pass ForwardFoldScaleAxis() {
-runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-[=](Function f, IRModule m, PassContext pc) {
-return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f));
-};
+runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = [=](
+Function f, IRModule m, PassContext pc) {
+return Downcast<Function>(relay::fold_scale_axis::ForwardFoldScaleAxis(f));
+};
 return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
 }

 TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis);

 Pass BackwardFoldScaleAxis() {
-runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-[=](Function f, IRModule m, PassContext pc) {
-return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f));
-};
+runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = [=](
+Function f, IRModule m, PassContext pc) {
+return Downcast<Function>(relay::fold_scale_axis::BackwardFoldScaleAxis(f));
+};
 return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
 }
