From f3ac8bb2583ca438fdc50932a1ce804d25deb6c8 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 17 May 2019 10:43:03 -0700 Subject: [PATCH] save --- include/tvm/relay/pass.h | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8445939625e40..49c25c64b3e60 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -507,9 +507,23 @@ inline pass::Pass FuseOpsPass(int fuse_opt_level) { * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +inline pass::Pass ForwardRewritePass(const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function + fmulti_ref_trigger = nullptr) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, pass::PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_map_attr_name, + fcontext, + fmulti_ref_trigger)); + }; + return pass::CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + } /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. @@ -521,9 +535,22 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +inline pass::Pass ForwardRewritePass(const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, pass::PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_func, + fcontext, + fmulti_ref_trigger)); + }; + return pass::CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); +} /*! * \brief Rewrite the annotated program.