Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed May 16, 2019
1 parent 7d82dea commit 4b302cf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 29 deletions.
17 changes: 5 additions & 12 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ Pass CreateModulePass(
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
Expand Down Expand Up @@ -489,18 +489,11 @@ TVM_DLL Expr FoldConstant(const Expr& expr);
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);

inline pass::Pass FuseOpsPass(int fuse_opt_level) {
runtime::TypedPackedFunc<Module(Module, pass::PassContext)> pass_func =
[=](Module m, pass::PassContext pc) {
Module new_m = ModuleNode::make(m->functions, m->type_definitions);
for (const auto& f : new_m->functions) {
new_m->Update(f.first, Downcast<Function>(FuseOps(f.second, fuse_opt_level, new_m)));
}
return new_m;
runtime::TypedPackedFunc<Function(Function, Module, pass::PassContext)> pass_func =
[=](Function f, Module m, pass::PassContext pc) {
return Downcast<Function>(FuseOps(f, fuse_opt_level, m));
};
return pass::CreateModulePass(pass_func,
1,
"fuse_ops",
{});
return pass::CreateFunctionPass(pass_func, 1, "fuse_ops", {});
}

/*!
Expand Down
26 changes: 9 additions & 17 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;

FunctionPassNode() = default;

Expand Down Expand Up @@ -145,7 +145,7 @@ class FunctionPassNode : public PassNode {
void SetContext(const PassContext& pass_ctx) final;

TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
PassInfo pass_info);

static constexpr const char* _type_key = "relay.FunctionPass";
Expand Down Expand Up @@ -305,7 +305,7 @@ void ModulePassNode::SetContext(const PassContext& pass_ctx) {
}

FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, PassContext)> pass_func,
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_node<FunctionPassNode>();
n->pass_func = std::move(pass_func);
Expand All @@ -320,22 +320,14 @@ Module FunctionPassNode::operator()(const Module& mod) const {
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
std::vector<std::pair<GlobalVar, Function>> updated_funcs;
ModuleNode* mod_node = mod.operator->();
for (const auto& it : mod_node->functions) {
if (!SkipFunction(it.second)) {
auto updated_func = pass_func(it.second, pass_ctx_);
CHECK(updated_func.defined());
updated_funcs.push_back({std::move(it.first), std::move(updated_func)});
}
}
Module new_mod = ModuleNode::make({}, mod->type_definitions);

// Update the optimized functions.
for (const auto& it : updated_funcs) {
mod_node->Update(it.first, it.second);
for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx_);
new_mod->Add(it.first, updated_func);
}

return GetRef<Module>(mod_node);
return new_mod;
}

void FunctionPassNode::SetContext(const PassContext& pass_ctx) {
Expand Down Expand Up @@ -406,7 +398,7 @@ Pass CreateModulePass(
}

Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required) {
Expand Down

0 comments on commit 4b302cf

Please sign in to comment.