Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FoldScaleAxis became non-recursive #8325

Merged
merged 2 commits into from
Jul 15, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 78 additions & 59 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ using FForwardRewrite = TypedPackedFunc<Expr(const Call& ref_call, const Array<E
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
//----------------------------------------------
class ForwardPrep : private ExprVisitor {
class ForwardPrep : private MixedModeVisitor {
public:
std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
this->Update(body, NullValue<Message>());
Expand Down Expand Up @@ -585,15 +585,22 @@ RELAY_REGISTER_OP("nn.conv2d")

Expr ForwardFoldScaleAxis(const Expr& data) {
auto message = ForwardPrep().Prepare(data);
auto fcontext = [&](const Call& call) -> ObjectRef {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
} else {
return ObjectRef(nullptr);
for (const auto& m : message) {
if (m.second.defined()) {
// run optimization
auto fcontext = [&](const Call& call) -> ObjectRef {
auto it = message.find(call.get());
if (it != message.end()) {
return it->second;
} else {
return ObjectRef(nullptr);
}
};
return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}
};
return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext);
}
// no messages - no optimization
return data;
}

//----------------------------------------
Expand All @@ -618,7 +625,7 @@ using FBackwardTransform =
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------

class BackwardPrep : private ExprVisitor {
class BackwardPrep : private MixedModeVisitor {
public:
// The message on each node.
std::unordered_map<const Object*, Message> Prepare(const Expr& body) {
Expand All @@ -643,6 +650,14 @@ class BackwardPrep : private ExprVisitor {
// We only allow propagation of scale backward
// if the expression is only referred by a single parent.
if (rit->second != 1) return;
Array<Message> in_messages = GetInMessages(call);
Message out_message = f(GetRef<Call>(call), in_messages);
if (out_message.defined()) {
message_[call] = out_message;
}
}

Array<Message> GetInMessages(const CallNode* call) {
Array<Message> in_messages;
for (Expr arg : call->args) {
auto it = message_.find(arg.get());
Expand All @@ -652,52 +667,34 @@ class BackwardPrep : private ExprVisitor {
in_messages.push_back(NullValue<Message>());
}
}
Message out_message = f(GetRef<Call>(call), in_messages);
if (out_message.defined()) {
message_[call] = out_message;
}
return in_messages;
}
};

class BackwardTransformerNode : public Object, private ExprMutator {
/*
* Hybrid apporach is used with the transformation
* itself is recursive but the traversal is non-recursive
*/
class BackwardTransformerNode : public Object, private MixedModeMutator {
public:
using MixedModeMutator::Mutate;
// Run forward transform.
Expr Fold(Expr expr) {
message_ = BackwardPrep().Prepare(expr);
return this->Mutate(expr);
}
/*!
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param axes The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr Transform(const Expr& expr, Message message, Expr scale) {
// NOTE: the result of Transform is memoized.
if (const CallNode* call_node = expr.as<CallNode>()) {
return Transform(call_node, message, scale);
} else {
ICHECK(!message.defined()) << "outstanding scale";
return ExprMutator::VisitExpr(expr);
for (const auto& m : message_) {
if (m.second.defined()) {
// run optimization
return this->Mutate(expr);
}
}
// no messages - no optimization
return expr;
}

/*!
* \brief Normal way of mutating call node.
* \param call_node The call node to be mutated.
* \return the result of the call Mutation.
* \brief Transform the expr to consider the scaling.
*/
Expr NormalCallTransform(const CallNode* call_node) {
const Call call = GetRef<Call>(call_node);
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
}
Expr new_expr = ExprMutator::VisitExpr_(call_node);
memo_[call] = new_expr;
return new_expr;
}
Expr Transform(const Expr& expr, Message message, Expr scale);
/*!
* \brief Get the message propogated to the expr.
* \param expr The expresison.
Expand All @@ -719,11 +716,12 @@ class BackwardTransformerNode : public Object, private ExprMutator {
// Valid axes on each node.
std::unordered_map<const Object*, Message> message_;
// Override mutation of call.
Expr VisitExpr_(const CallNode* call_node) final {
return Transform(call_node, NullValue<Message>(), NullValue<Expr>());
Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
return Transform(GetRef<Call>(call_node), NullValue<Message>(), NullValue<Expr>());
}
// Transform of CallNode.
Expr Transform(const CallNode* call_node, Message message, Expr scale);

public:
Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); }
};

class BackwardTransformer : public ObjectRef {
Expand All @@ -736,21 +734,39 @@ class BackwardTransformer : public ObjectRef {
using ContainerType = BackwardTransformerNode;
};

Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) {
static const auto& ftransform = Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) {
/*!
d-smirnov marked this conversation as resolved.
Show resolved Hide resolved
* \brief Transform the expr to consider the scaling.
*
* \param expr The input expression.
* \param message The axes to scale.
* \param scale The scale applied to the axes.
* \return The result of transformation.
*/
Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) {
if (const CallNode* call_node = expr.as<CallNode>()) {
static const auto& ftransform =
Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
const Call call = GetRef<Call>(call_node);
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
// ignore if there is a message
if (!message.defined()) {
const auto it = memo_.find(call);
if (it != memo_.end()) {
return it->second;
}
}
Expr new_expr = NullValue<Expr>();
if (f != nullptr) {
new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
} else {
ICHECK(!message.defined()) << "outstanding scale";
new_expr = NormalCallTransform(call.operator->());
}
Expr new_expr = f(GetRef<Call>(call_node), message, scale, GetRef<BackwardTransformer>(this));
memo_[call] = new_expr;
return new_expr;
} else {
ICHECK(!message.defined()) << "outstanding scale";
return NormalCallTransform(call_node);
return this->Mutate(expr);
}
}

d-smirnov marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -813,6 +829,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
if (!message.defined()) {
return transformer->NormalCallTransform(call.operator->());
}

Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
StructuralEqual equal;
Expand Down Expand Up @@ -959,7 +976,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
} else {
wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
{big_ko_axis, small_ko_axis});
if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
if (!wscale.defined()) {
return transformer->NormalCallTransform(call.operator->());
}
}
weight = Multiply(weight, wscale);
return Call(call->op, {data, weight}, call->attrs, call->type_args);
Expand Down