Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed May 17, 2019
1 parent 4b302cf commit f3ac8bb
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,9 +507,23 @@ inline pass::Pass FuseOpsPass(int fuse_opt_level) {
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

inline pass::Pass ForwardRewritePass(const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)>
fmulti_ref_trigger = nullptr) {
runtime::TypedPackedFunc<Function(Function, Module, pass::PassContext)> pass_func =
[=](Function f, Module m, pass::PassContext pc) {
return Downcast<Function>(ForwardRewrite(f,
rewrite_map_attr_name,
fcontext,
fmulti_ref_trigger));
};
return pass::CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
}

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
Expand All @@ -521,9 +535,22 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr,
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

inline pass::Pass ForwardRewritePass(const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr) {
runtime::TypedPackedFunc<Function(Function, Module, pass::PassContext)> pass_func =
[=](Function f, Module m, pass::PassContext pc) {
return Downcast<Function>(ForwardRewrite(f,
rewrite_func,
fcontext,
fmulti_ref_trigger));
};
return pass::CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
}

/*!
* \brief Rewrite the annotated program.
Expand Down

0 comments on commit f3ac8bb

Please sign in to comment.