diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 6f8ac69bde26..04b275431f2b 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { * * ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order. * - * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will + * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original * node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex @@ -408,7 +408,7 @@ class ExprRewriter { /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes * - * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the + * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call, * PostOrderRewrite provides the original node and the node with altered inputs for use by the * ExprRewriter. diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 250dd69cd62f..01411a63eb2f 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -35,19 +35,18 @@ namespace legalize { // Call registered FTVMLegalize of an op // Returns the legalized expression -class Legalizer : public ExprMutator { +class Legalizer : public ExprRewriter { public: explicit Legalizer(const std::string& legalize_map_attr_name) : legalize_map_attr_name_{legalize_map_attr_name} {} - Expr VisitExpr_(const CallNode* call_node) { + Expr Rewrite_(const CallNode* call_node, const Expr& post) override { // Get the new_call node without any changes to current call node. - Expr new_e = ExprMutator::VisitExpr_(call_node); - Call new_call = Downcast(new_e); + Call new_call = Downcast(post); // Check if the string is registered in the OpRegistry. if (!Op::HasAttr(legalize_map_attr_name_)) { - return new_e; + return post; } // Collect the registered legalize function. @@ -70,19 +69,18 @@ class Legalizer : public ExprMutator { // Transform the op by calling the registered legalize function. Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types); - // Reassign new_e if the transformation succeeded. + // Return the new expr if the transformation succeeded. if (legalized_value.defined()) { // Check that the returned Expr from legalize is CallNode. const CallNode* legalized_call_node = legalized_value.as(); CHECK(legalized_call_node) << "Can only replace the original operator with another call node"; - - new_e = legalized_value; + return legalized_value; } } } - return new_e; + return post; } private: @@ -90,7 +88,8 @@ class Legalizer : public ExprMutator { }; Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { - return Legalizer(legalize_map_attr_name).Mutate(expr); + auto rewriter = Legalizer(legalize_map_attr_name); + return PostOrderRewrite(expr, &rewriter); } } // namespace legalize