Skip to content

Commit

Permalink
Legalize - Use Non-recursive Rewriter. (apache#5296)
Browse files Browse the repository at this point in the history
* Legalize - Use Non-recursive Rewriter.

* Cleanup.
  • Loading branch information
anijain2305 authored and dpankratz committed Apr 24, 2020
1 parent d56bf85 commit 0ac8338
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
Expand Down Expand Up @@ -408,7 +408,7 @@ class ExprRewriter {

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
Expand Down
19 changes: 9 additions & 10 deletions src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,18 @@ namespace legalize {

// Call registered FTVMLegalize of an op
// Returns the legalized expression
class Legalizer : public ExprMutator {
class Legalizer : public ExprRewriter {
public:
explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {}

Expr VisitExpr_(const CallNode* call_node) {
Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
// Get the new_call node without any changes to current call node.
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);
Call new_call = Downcast<Call>(post);

// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e;
return post;
}

// Collect the registered legalize function.
Expand All @@ -70,27 +69,27 @@ class Legalizer : public ExprMutator {
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
// Return the new expr if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
return legalized_value;
}
}
}

return new_e;
return post;
}

private:
std::string legalize_map_attr_name_;
};

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
return Legalizer(legalize_map_attr_name).Mutate(expr);
auto rewriter = Legalizer(legalize_map_attr_name);
return PostOrderRewrite(expr, &rewriter);
}

} // namespace legalize
Expand Down

0 comments on commit 0ac8338

Please sign in to comment.