From 112945a856b74614fa7541dbb3bb96b5b5bb0598 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 28 Oct 2018 14:20:00 -0700
Subject: [PATCH 1/2] [RELAY][PASS] FoldScaleAxis Backward

---
 include/tvm/relay/expr_functor.h              |   6 +-
 python/tvm/relay/ir_pass.py                   |  29 ++
 src/relay/ir/expr_functor.cc                  |  12 +-
 src/relay/pass/fold_scale_axis.cc             | 453 +++++++++++++++++-
 src/relay/pass/pattern_util.h                 |  23 +-
 .../python/relay/test_pass_fold_scale_axis.py | 177 ++++++-
 6 files changed, 666 insertions(+), 34 deletions(-)

diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index bf4025f79224..85a6b502d845 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -135,9 +135,9 @@ class ExprVisitor
   void VisitExpr_(const TupleGetItemNode* op) override;
   virtual void VisitType(const Type& t);
 
- private:
-  // internal visited flag.
-  std::unordered_set<const Node*> visited_;
+ protected:
+  // Internal visiting counter.
+  std::unordered_map<const Node*, size_t> visit_counter_;
 };
 
 /*!
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 6adfaacdc86d..82afa83ee376 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
     return _ir_pass.infer_type(expr, env)
 
 
+def backward_fold_scale_axis(expr):
+    """Backward fold axis scaling into weights of conv2d/dense.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression; we expect that expr's types
+        are fully inferred by infer_type.
+
+    Returns
+    -------
+    folded_expr : tvm.relay.Expr
+        The folded expression after transformation.
+
+    Note
+    ----
+    It is recommended to call backward_fold_scale_axis
+    before using forward_fold_scale_axis,
+    as backward folding targets the common conv-bn pattern.
+    """
+    return _ir_pass.backward_fold_scale_axis(expr)
+
+
 def forward_fold_scale_axis(expr):
     """Fold the scaling of axis into weights of conv2d/dense.
 
@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
     -------
     folded_expr : tvm.relay.Expr
         The folded expression after transformation.
+
+    Note
+    ----
+    It is recommended to call backward_fold_scale_axis
+    before using forward_fold_scale_axis,
+    as backward folding targets the common conv-bn pattern.
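+
+    Examples
+    --------
+    A minimal usage sketch mirroring the unit tests; the shapes
+    and variable names below are illustrative only:
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(2, 4, 10, 10))
+        weight = relay.var("weight")
+        in_scale = relay.var("in_scale", shape=(4, 1, 1))
+        # scaling on the input channel axis, to be folded into weight
+        y = relay.nn.conv2d(relay.multiply(x, in_scale), weight,
+                            channels=8, kernel_size=(3, 3))
+        f = relay.Function([x, weight, in_scale], y)
+        f = relay.ir_pass.infer_type(f)
+        f = relay.ir_pass.forward_fold_scale_axis(f)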
""" return _ir_pass.forward_fold_scale_axis(expr) diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index b7a752d43a5c..ed7c1d1d1e5a 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { Type ExprMutator::VisitType(const Type& t) { return t; } void ExprVisitor::VisitExpr(const Expr& expr) { - if (visited_.count(expr.get())) return; - using TParent = ExprFunctor; - TParent::VisitExpr(expr); - visited_.insert(expr.get()); + auto it = visit_counter_.find(expr.get()); + if (it != visit_counter_.end()) { + ++it->second; + } else { + using TParent = ExprFunctor; + TParent::VisitExpr(expr); + visit_counter_.insert({expr.get(), 1}); + } } void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index b1c767704372..68b79d852917 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -24,9 +24,9 @@ namespace fold_scale_axis { using runtime::TypedPackedFunc; -// FoldScaleAxisFoward algorithm: +// FoldScaleAxis algorithm: // -// The general idea is that we transform Expr to tuple of +// The general idea is to transform Expr to tuple of // (value, axes, scale), where the final result satiesfies: // // result = value @@ -41,9 +41,14 @@ using runtime::TypedPackedFunc; // we run a backward "preparation phase", which propagates the demand // of the potential axes scaling back to its input. // -// The folding process is done in two steps: +// Forward folding process is done in two steps: // - Prepare phase: backward propagation of demand. // - Transform phase: forward transformation, +// +// Similarly, borward folding process is done in two steps: +// - Prepare phase: forward propagation of demand. +// - Transform phase: transformation by push down the axes scale signal to inputs. +// /*! * \brief sorted array axis, can also be nullptr. @@ -144,7 +149,7 @@ using FForwardTransform = TypedPackedFunc< //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- -class FScaleAxisForwardPrep : private ExprVisitor { +class ForwardPrep : private ExprVisitor { public: std::unordered_map Prepare(const Expr& body) { @@ -255,12 +260,12 @@ class FScaleAxisForwardPrep : private ExprVisitor { } }; -class FScaleAxisForwardTransform : private ExprMutator { +class ForwardTransformer : private ExprMutator { public: // Transform expression. 
-  Expr Transform(Expr expr) {
+  Expr Fold(Expr expr) {
     expected_scale_axes_ =
-        FScaleAxisForwardPrep().Prepare(expr);
+        ForwardPrep().Prepare(expr);
     return this->Mutate(expr);
   }
 
@@ -346,13 +351,13 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
 }
 
 STuple ReluForwardTransform(const Call& ref_call,
-                           const AxesSet& expected_axes,
-                           const Array<STuple>& sargs) {
+                            const AxesSet& expected_axes,
+                            const Array<STuple>& sargs) {
   if (!sargs[0]->axes.defined()) return STuple();
   // return transformed conv2d
   auto rnode = make_node<ScaledExprNode>();
   rnode->value = CallNode::make(
-      ref_call->op, {sargs[0]->value}, ref_call->attrs, {});
+      ref_call->op, {sargs[0]->value}, ref_call->attrs, ref_call->type_args);
   rnode->scale = sargs[0]->scale;
   rnode->axes = sargs[0]->axes;
   return STuple(rnode);
@@ -474,8 +479,6 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   Layout weight_layout(param->weight_layout);
   int c_big_axis = data_layout.indexof('C');
   int c_small_axis = data_layout.indexof('c');
-  const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
-  CHECK(tdata) << "require checked type";
 
   CHECK_GE(c_big_axis, 0);
   AxesSet data_axes = NullValue<AxesSet>();
@@ -486,8 +489,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   //
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d =
-      is_const_int(tdata->shape[c_big_axis], param->groups);
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
   if (weight_layout.indexof('i') < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
@@ -515,18 +517,24 @@ STuple Conv2DForwardTransform(const Call& ref_call,
   CHECK_EQ(weight_layout.indexof('i'), -1);
   CHECK(sdata->axes.size() == 1 &&
         c_big_axis == sdata->axes[0]->value);
+
+  int big_oc_axis = weight_layout.indexof('O');
   int big_ic_axis = weight_layout.indexof('I');
-  const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
 
   // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d =
-      is_const_int(tdata->shape[c_big_axis], param->groups);
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
   CHECK(param->groups == 1 || is_depthwise_conv2d);
 
+  Expr weight = sargs[1]->value;
+
   // match the ic_axis
-  Expr scale = ExpandBiasToMatchAxis(
-      sdata->scale, weight_layout.ndim(), {big_ic_axis});
-  Expr weight = Multiply(sargs[1]->value, scale);
+  if (is_depthwise_conv2d) {
+    Expr scale = ExpandBiasToMatchAxis(
+        sdata->scale, weight_layout.ndim(), {big_oc_axis});
+    weight = Multiply(weight, scale);
+  } else {
+    Expr scale = ExpandBiasToMatchAxis(
+        sdata->scale, weight_layout.ndim(), {big_ic_axis});
+    weight = Multiply(weight, scale);
+  }
   // return transformed conv2d
   auto rnode = make_node<ScaledExprNode>();
   rnode->value = CallNode::make(
@@ -542,13 +550,416 @@ RELAY_REGISTER_OP("nn.conv2d")
 
 Expr ForwardFoldScaleAxis(Expr data) {
-  return FScaleAxisForwardTransform().Transform(data);
+  return ForwardTransformer().Fold(data);
 }
 
 // Expose the FoldScaleAxisForward
 TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
 .set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
 
+//----------------------------------------
+// Implement backward transformations.
+//----------------------------------------
+class BackwardTransformer;
+
+/*!
+ * \brief Preparation function for pass scale backward.
+ * \param call The call node.
+ * \param in_scale_axes Allowed input scaling.
+ * \return The result scaling on axes of the input.
+ */
+using FBackwardPrep = TypedPackedFunc<
+    AxesSet(const Call& call, const Array<AxesSet>& in_scale_axes)>;
+
+/*!
+ * \brief Transformation function for pass scale backward.
+ * \param call The call node.
+ * \param axes The axes to scale.
+ * \param scale The scale applied to the axes.
+ * \param transformer The backward transformer, used to recurse into the arguments.
+ * \return The result of the transformation.
+ */
+using FBackwardTransform = TypedPackedFunc<
+    Expr(const Call& call,
+         const AxesSet& axes,
+         const Expr& scale,
+         const BackwardTransformer& transformer)>;
+
+//----------------------------------------------
+// Generic Visitors for FScaleAxisBackward
+//----------------------------------------------
+/*!
+ * \brief Get reference counter of each internal ExprNode in body.
+ * \param body The body expression.
+ * \return The reference count mapping.
+ */
+std::unordered_map<const Node*, size_t>
+GetExprRefCount(const Expr& body) {
+  class ExprRefCounter : private ExprVisitor {
+   public:
+    std::unordered_map<const Node*, size_t>
+    Get(const Expr& body) {
+      this->VisitExpr(body);
+      return std::move(this->visit_counter_);
+    }
+  };
+  return ExprRefCounter().Get(body);
+}
+
+class BackwardPrep : private ExprVisitor {
+ public:
+  // Prepare the message on each node.
+  std::unordered_map<const Node*, AxesSet>
+  Prepare(const Expr& body) {
+    ref_counter_ = GetExprRefCount(body);
+    this->VisitExpr(body);
+    return std::move(message_);
+  }
+
+ private:
+  // The message on each node.
+  std::unordered_map<const Node*, AxesSet> message_;
+  // Reference counter of an internal expr.
+  std::unordered_map<const Node*, size_t> ref_counter_;
+  // Visit the expression.
+  void VisitExpr_(const CallNode* call) {
+    ExprVisitor::VisitExpr_(call);
+    static const auto& fprep =
+        Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
+    auto f = GetFunc(fprep, call->op);
+    if (f == nullptr) return;
+    auto rit = ref_counter_.find(call);
+    CHECK(rit != ref_counter_.end());
+    // We only allow propagation of scale backward
+    // if the expression is referred to by only a single parent.
+    if (rit->second != 1) return;
+    Array<AxesSet> in_axes;
+    for (Expr arg : call->args) {
+      auto it = message_.find(arg.get());
+      if (it != message_.end()) {
+        in_axes.push_back(it->second);
+      } else {
+        in_axes.push_back(NullValue<AxesSet>());
+      }
+    }
+    AxesSet out_axes = f(GetRef<Call>(call), in_axes);
+    if (out_axes.defined()) {
+      message_[call] = out_axes;
+    }
+  }
+};
+
+class BackwardTransformerNode :
+      public Node,
+      private ExprMutator {
+ public:
+  // Run backward transform.
+  Expr Fold(Expr expr) {
+    expected_scale_axes_ = BackwardPrep().Prepare(expr);
+    return this->Mutate(expr);
+  }
+  /*!
+   * \brief Transform the expr to consider the scaling.
+   *
+   * \param expr The input expression.
+   * \param axes The axes to scale.
+   * \param scale The scale applied to the axes.
+   * \return The result of transformation.
+   */
+  Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
+    // NOTE: the result of Transform is not memoized.
+    // However, in the current rule, Transform will
+    // only be called on an expr that is referred to once.
+    if (const CallNode* call_node = expr.as<CallNode>()) {
+      return Transform(call_node, axes, scale);
+    } else {
+      CHECK(!axes.defined()) << "outstanding scale";
+      return ExprMutator::VisitExpr(expr);
+    }
+  }
+  /*!
+   * \brief Normal way of mutating call node.
+   * \param call_node The call node to be mutated.
+   * \return The result of the call mutation.
+   */
+  Expr NormalCallTransform(const CallNode* call_node) {
+    return ExprMutator::VisitExpr_(call_node);
+  }
+  /*!
+   * \brief Get the expected axes on expr.
+   * \param expr The expression.
+   * \return The expected axes.
+   */
+  AxesSet GetExpectedAxes(const Expr& expr) const {
+    auto it = expected_scale_axes_.find(expr.get());
+    if (it != expected_scale_axes_.end()) return it->second;
+    return NullValue<AxesSet>();
+  }
+
+  // The transformer is not serializable.
+  void VisitAttrs(tvm::AttrVisitor* v) final {}
+
+  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
+  TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
+
+ private:
+  // Valid axes on each node.
+  std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
+  // Override mutation of call.
+  Expr VisitExpr_(const CallNode* call_node) final {
+    return Transform(call_node, NullValue<AxesSet>(), NullValue<Expr>());
+  }
+  // Transform of CallNode.
+  Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale);
+};
+
+class BackwardTransformer : public NodeRef {
+ public:
+  BackwardTransformer() {}
+  explicit BackwardTransformer(
+      ::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
+  }
+  BackwardTransformerNode* operator->() const {
+    return static_cast<BackwardTransformerNode*>(node_.get());
+  }
+  using ContainerType = BackwardTransformerNode;
+};
+
+Expr BackwardTransformerNode::Transform(
+    const CallNode* call_node, AxesSet axes, Expr scale) {
+  static const auto& ftransform =
+      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
+  auto f = GetFunc(ftransform, call_node->op);
+  if (f != nullptr) {
+    return f(GetRef<Call>(call_node),
+             axes,
+             scale,
+             GetRef<BackwardTransformer>(this));
+  } else {
+    CHECK(!axes.defined()) << "outstanding scale";
+    return NormalCallTransform(call_node);
+  }
+}
+
+
+//----------------------------------------------
+// Per operator defs for FScaleAxisBackward
+//----------------------------------------------
+
+// Intermediate operators
+AxesSet ReluBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  return in_axes[0];
+}
+
+Expr ReluBackwardTransform(const Call& call,
+                           const AxesSet& axes,
+                           const Expr& scale,
+                           const BackwardTransformer& transformer) {
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  Expr input = transformer->Transform(
+      call->args[0], axes, scale);
+  return CallNode::make(call->op, {input}, call->attrs, call->type_args);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+
+// AddSub
+AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  AttrsEqual equal;
+  if (in_axes[0].defined() &&
+      MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) {
+    return in_axes[0];
+  } else if (in_axes[1].defined() &&
+             MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) {
+    return in_axes[1];
+  } else if (in_axes[0].defined() &&
+             in_axes[1].defined() &&
+             equal(in_axes[0], in_axes[1]) &&
+             equal(tlhs->shape, trhs->shape)) {
+    // add of two elements.
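+    // Both operands share the same shape and expect the same scale
+    // axes, so the scale signal can be pushed through both inputs.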
+    return in_axes[0];
+  } else {
+    return NullValue<AxesSet>();
+  }
+}
+
+Expr AddSubBackwardTransform(const Call& call,
+                             const AxesSet& axes,
+                             const Expr& scale,
+                             const BackwardTransformer& transformer) {
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
+  AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
+  AttrsEqual equal;
+
+  if (lhs_axes.defined() && rhs_axes.defined()) {
+    CHECK(equal(lhs_axes, rhs_axes));
+    CHECK(equal(axes, lhs_axes));
+    Expr lhs = transformer->Transform(call->args[0], axes, scale);
+    Expr rhs = transformer->Transform(call->args[1], axes, scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else if (lhs_axes.defined()) {
+    CHECK(equal(axes, lhs_axes));
+    Expr lhs = transformer->Transform(call->args[0], axes, scale);
+    Expr rhs = transformer->Transform(
+        call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
+    Expr rhs_scale = ExpandBiasToMatchAxis(
+        scale, tlhs->shape.size(), axes);
+    rhs = Multiply(rhs, rhs_scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else if (rhs_axes.defined()) {
+    CHECK(equal(axes, rhs_axes));
+    Expr lhs = transformer->Transform(
+        call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
+    Expr rhs = transformer->Transform(call->args[1], axes, scale);
+    Expr lhs_scale = ExpandBiasToMatchAxis(
+        scale, trhs->shape.size(), axes);
+    lhs = Multiply(lhs, lhs_scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else {
+    LOG(FATAL) << "outstanding scale";
+    return Expr();
+  }
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+
+RELAY_REGISTER_OP("add")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+
+// Producer operators
+// Multiply produces the scale-axis pair.
+Expr MultiplyBackwardTransform(const Call& call,
+                               const AxesSet& axes,
+                               const Expr& scale,
+                               const BackwardTransformer& transformer) {
+  CHECK(!axes.defined()) << "outstanding scale";
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
+  AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
+  if (lhs_axes.defined()) {
+    // NOTE: we won't recursively call mutation on the scale part,
+    // since there is no scaling opportunity within the scale part.
+    Expr rhs = call->args[1];
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
+      return transformer->Transform(call->args[0], lhs_axes, rhs);
+    }
+  } else if (rhs_axes.defined()) {
+    Expr lhs = call->args[0];
+    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
+      return transformer->Transform(call->args[1], rhs_axes, lhs);
+    }
+  }
+  return transformer->NormalCallTransform(call.operator->());
+}
+
+RELAY_REGISTER_OP("multiply")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
+
+// Consumer operators
+// Conv2D sends out a requirement of axis folding.
+AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  const auto* param = call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout out_layout(param->out_layout);
+  if (!out_layout.defined()) {
+    out_layout = Layout(param->data_layout);
+  }
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = out_layout.indexof('C');
+  int c_small_axis = out_layout.indexof('c');
+
+  CHECK_GE(c_big_axis, 0);
+  // For now, we only support the simple pattern (no folded weight/data).
+  // More general layouts can be supported under the current framework
+  // by using a unified layout transformation;
+  // we would only need to change the Prep and Mutate functions.
+  //
+  // only handle depthwise or full conv2d.
+  // TODO(tvm-team) handle grouped conv by reshape + bcast
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
+  if (weight_layout.indexof('o') < 0 &&
+      weight_layout.indexof('i') < 0 &&
+      c_small_axis < 0 &&
+      (param->groups == 1 || is_depthwise_conv2d)) {
+    return {c_big_axis};
+  } else {
+    return NullValue<AxesSet>();
+  }
+}
+
+// Conv2D consumes the scale axis during transformation.
+Expr Conv2DBackwardTransform(const Call& call,
+                             const AxesSet& axes,
+                             const Expr& scale,
+                             const BackwardTransformer& transformer) {
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  const auto* param = call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout out_layout(param->out_layout);
+  if (!out_layout.defined()) {
+    out_layout = Layout(param->data_layout);
+  }
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = out_layout.indexof('C');
+  CHECK_GE(c_big_axis, 0);
+  // For now, we only support the simple pattern (no folded weight/data).
+  // TODO(tvm-team) support general data layout
+  CHECK_EQ(weight_layout.indexof('o'), -1);
+  CHECK_EQ(weight_layout.indexof('i'), -1);
+  CHECK(axes.size() == 1 &&
+        c_big_axis == axes[0]->value);
+
+  int big_oc_axis = weight_layout.indexof('O');
+  // Check it must be depthwise or full conv2d.
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
+  CHECK(param->groups == 1 || is_depthwise_conv2d);
+
+  Expr data = transformer->Transform(
+      call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
+  Expr weight = transformer->Transform(
+      call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
+  // fold the scale into the weight along its output channel axis.
+  Expr wscale = ExpandBiasToMatchAxis(
+      scale, weight_layout.ndim(), {big_oc_axis});
+  weight = Multiply(weight, wscale);
+  return CallNode::make(
+      call->op, {data, weight}, call->attrs, call->type_args);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+
+Expr BackwardFoldScaleAxis(Expr data) {
+  return make_node<BackwardTransformerNode>()->Fold(data);
+}
+
+TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
+.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
+
 }  // namespace fold_scale_axis
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index a395e74cdf0b..a41e6c35b93a 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -11,6 +11,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
+#include "../op/nn/layout.h"
 
 namespace tvm {
 namespace relay {
@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
   return bias;
 }
 
+/*!
+ * \brief Check if the call is depthwise conv2d.
+ *
+ * \param call The conv2d call.
+ * \param param The conv2d attributes.
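+ * \param weight_layout The layout of the weight.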
+ * \return Whether it is depthwise_conv2d.
+ */
+inline bool IsDepthwiseConv2D(const Call& call,
+                              const Conv2DAttrs* param,
+                              const Layout& weight_layout) {
+  static const Layout kOIHW("OIHW");
+  auto wshape = ConvertLayout(
+      call->args[1]->type_as<TensorTypeNode>()->shape,
+      weight_layout, kOIHW);
+  return is_const_int(wshape[0], param->groups) &&
+      is_const_int(wshape[1], 1);
+}
+
+
 inline Expr Multiply(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("multiply");
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
+
 inline Expr Divide(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("divide");
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index 7ce3b35efe46..1b57bdce0e0c 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -62,14 +62,14 @@ def before(x, conv_weight, in_bias, in_scale, channels):
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
@@ -85,7 +85,7 @@ def expected(x, conv_weight, in_bias, in_scale, channels):
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x,
@@ -93,7 +93,7 @@ def expected(x, conv_weight, in_bias, in_scale, channels):
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
@@ -147,7 +147,176 @@ def check(shape, channels):
     check((2, 11, 10, 4), 4)
 
 
+def test_fold_bwd_simple():
+    """Simple testcase."""
+    def before(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, channels):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        squeezed_scale = relay.squeeze(out_scale, axis=[1, 2])
+        conv_weight = relay.multiply(
+            conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        out_bias = relay.multiply(out_bias,
+                                  relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
relay.var("out_scale", shape=(channels,)) + + y1 = before(x, weight, out_bias, out_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_expected = expected(x, weight, out_bias, out_scale, channels) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 8) + + +def test_fold_bwd_dual_path(): + """Dual path testcase.""" + def before(x, conv_weight, out_bias, out_scale, channels): + args = [x, conv_weight, out_bias, out_scale] + out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2) + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + y1 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y2 = relay.nn.relu(y2) + y = relay.add(y1, y2) + y = relay.multiply(y, out_scale) + return relay.Function(args, y) + + def expected(x, conv_weight, out_bias, out_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight, out_bias, out_scale] + out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2) + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + def fold_conv_weight(): + return relay.multiply( + conv_weight , + relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + y1 = relay.nn.conv2d(x, fold_conv_weight(), + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, fold_conv_weight(), + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y2 = relay.nn.relu(y2) + y = relay.add(y1, y2) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[1] + weight = relay.var("weight") + out_bias = relay.var("out_bias", shape=(channels,)) + out_scale = relay.var("out_scale", shape=(channels,)) + + y1 = before(x, weight, out_bias, out_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_expected = expected(x, weight, out_bias, out_scale, channels) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 8) + + +def test_fold_bwd_fail(): + """Dual path testcase.""" + def fail1(x, conv_weight, out_bias, out_scale, channels): + args = [x, conv_weight, out_bias, out_scale] + out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2) + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + y1 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1), + out_layout="CNHW") + # fold will fail because the axis from two path + # differs from each other. 
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def fail2(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y1 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y1)
+        # fold will fail because y1 is also referred to by y2
+        y1 = relay.multiply(y1, out_scale)
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+
+    def check(shape, channels, fbefore):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+        y1 = fbefore(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1)
+
+    check((4, 4, 10, 10), 4, fail1)
+    check((4, 4, 10, 10), 4, fail2)
+
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
     test_fold_fwd_fail()
+    test_fold_bwd_simple()
+    test_fold_bwd_dual_path()
+    test_fold_bwd_fail()

From 914438191c19ac02fc190686a4f93dfa236c2d34 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 29 Oct 2018 15:32:04 -0700
Subject: [PATCH 2/2] fix review comments

---
 src/relay/pass/fold_scale_axis.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 68b79d852917..e757118f33f2 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -45,7 +45,7 @@ using runtime::TypedPackedFunc;
 // - Prepare phase: backward propagation of demand.
 // - Transform phase: forward transformation.
 //
-// Similarly, borward folding process is done in two steps:
+// Similarly, the backward folding process is done in two steps:
 // - Prepare phase: forward propagation of demand.
 // - Transform phase: transformation by pushing down the axes scale signal to inputs.
 //
@@ -104,7 +104,7 @@ ValueType GetFunc(const OpMap<ValueType>& op_map,
 }
 
 /*!
- * \brief Preparation function for for pass scale forward.
+ * \brief Preparation function for pass scale forward.
  * \param call The call node.
  * \param out_scale_axes Possible scaling on axes of the output.
  * \return The result scaling on axes of the input.
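As the docstrings above note, backward folding is intended to run before forward folding, since it targets the common conv-bn pattern. A minimal end-to-end sketch of that ordering (shapes and variable names are illustrative, mirroring the unit tests):

    from tvm import relay

    x = relay.var("x", shape=(2, 4, 10, 10))
    weight = relay.var("weight")
    out_scale = relay.var("out_scale", shape=(8, 1, 1))

    # conv2d followed by a per-channel multiply, the residue left
    # behind by a simplified batch norm
    y = relay.nn.conv2d(x, weight, channels=8,
                        kernel_size=(3, 3), padding=(1, 1))
    y = relay.multiply(y, out_scale)
    f = relay.Function([x, weight, out_scale], y)

    # types must be fully inferred before folding
    f = relay.ir_pass.infer_type(f)
    f = relay.ir_pass.backward_fold_scale_axis(f)
    f = relay.ir_pass.forward_fold_scale_axis(f)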