From ac20b98fe59e682f4aaeb3fc5f6ed734aa668a27 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 10 Jun 2019 17:47:31 -0700 Subject: [PATCH] [relay][vm] move vm opt passes to pass manager (#3323) --- python/tvm/relay/backend/vm.py | 52 ++++++++----- src/relay/backend/vm/compiler.cc | 24 ++++-- src/relay/backend/vm/inline_primitives.cc | 92 ++++++++++++----------- src/relay/backend/vm/lambda_lift.cc | 80 ++++++++++---------- src/relay/pass/pass_manager.cc | 15 ++-- 5 files changed, 150 insertions(+), 113 deletions(-) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index bebadd167fe9..3b9946a3958d 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -20,24 +20,45 @@ Implements a Python interface to compiling and executing on the Relay VM. """ +import numpy as np + import tvm from tvm._ffi.function import Object -import numpy as np -from .. import ir_pass +from .. import transform from ..backend.interpreter import Executor -from ..expr import GlobalVar, Function, Expr +from ..expr import GlobalVar, Expr from . import _vm Object = Object -def optimize(expr, mod=None): - # TODO: We need to move this optimization code into the optimizer/pass manager - ck_expr = ir_pass.infer_type(expr, mod=mod) - simplified_expr = ir_pass.simplify_inference(ck_expr) - simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod) - fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod) - ck_fused = ir_pass.infer_type(fused_expr, mod=mod) - return ck_fused +def optimize(mod): + """Perform several optimizations on a module before executing it in the + Relay virtual machine. + + Parameters + ---------- + mod : tvm.relay.Module + The module to optimize. + + Returns + ------- + ret : tvm.relay.Module + The optimized module. + """ + main_func = mod[mod.entry_func] + + opt_passes = [] + if not main_func.params and isinstance(main_func.body, GlobalVar): + opt_passes.append(transform.EtaExpand()) + + opt_passes = opt_passes + [ + transform.SimplifyInference(), + transform.FuseOps(), + transform.InferType() + ] + + seq = transform.Sequential(opt_passes) + return seq(mod) def _convert(arg, cargs): if isinstance(arg, np.ndarray): @@ -76,15 +97,8 @@ def _eval_vm(mod, ctx, *args): args: List[tvm.NDArray, np.ndarray] The arguments to evaluate. """ - main_func = mod[mod.entry_func] - - if not main_func.params and isinstance(main_func.body, GlobalVar): - main_func = ir_pass.eta_expand(main_func.body, mod) - - assert isinstance(main_func, Function) - main_func = optimize(mod[mod.entry_func], mod) - mod[mod.entry_func] = main_func + mod = optimize(mod) args = list(args) assert isinstance(args, list) cargs = convert(args) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index db98a9a9d3fd..07633fc346ec 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -38,15 +38,22 @@ namespace tvm { namespace relay { + +namespace transform { + +Pass LambdaLift(); +Pass InlinePrimitives(); + +} // namespace transform + namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; +using namespace relay::transform; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); -Module LambdaLift(const Module& module); -Module InlinePrimitives(const Module& module); template using NodeMap = std::unordered_map; @@ -560,10 +567,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F } Module OptimizeModule(const Module& mod) { - ToANormalForm(mod->entry_func, mod); - InlinePrimitives(mod); - LambdaLift(mod); - return InlinePrimitives(mod); + transform::Sequential seq({transform::ToANormalForm(), + transform::InlinePrimitives(), + transform::LambdaLift(), + transform::InlinePrimitives()}); + auto pass_ctx = transform::PassContext::Create(); + tvm::With ctx(pass_ctx); + return seq(mod); } void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index b033a37e42b8..1e561f8a8214 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -37,6 +37,21 @@ namespace tvm { namespace relay { namespace vm { +// TODO(@jroesch): write verifier + +/* This pass will eliminate primitives which have been lifted by the ANF + * transform inlining them directly into call sites. + * + * This makes VM related code generation easier as the call target is always + * a primitive function. + * + * let prim = fn(...) { ... }; + * prim(...) + * + * will become: + * + * (fn(...) { ... })(...) + */ struct PrimitiveInliner : ExprMutator { Module module_; std::unordered_map var_map; @@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator { } } - Function Inline(const Function& func) { - DLOG(INFO) << "Before inlining primitives: " << std::endl - << "func= " << AsText(func, false) << std::endl; - - auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); - - inlined = Downcast(DeadCodeElimination(inlined)); - - DLOG(INFO) << "After inlining primitives" << std::endl - << "after_func= " << AsText(inlined, false) << std::endl; - return inlined; + Module Inline() { + auto gvar_funcs = module_->functions; + for (auto pair : gvar_funcs) { + auto global = pair.first; + auto func = pair.second; + DLOG(INFO) << "Before inlining primitives: " << global + << std::endl << AsText(func, false); + + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(global, func, true); + + DLOG(INFO) << "After inlining primitives: " << global + << std::endl << AsText(func, false); + } + return module_; } }; -// TODO(@jroesch): write verifier - -/* This pass will eliminate primitives which have been lifted by the ANF - * transform inlining them directly into call sites. - * - * This makes VM related code generation easier as the call target is always - * a primitive function. - * - * let prim = fn(...) { ... }; - * prim(...) - * - * will become: - * - * (fn(...) { ... })(...) - */ -Module InlinePrimitives(const Module& module) { - PrimitiveInliner inliner(module); +} // namespace vm - tvm::Map updates; +namespace transform { - // There is an ordering bug here. - for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, inliner.Inline(func)); - } +Pass InlinePrimitives() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::PrimitiveInliner(m).Inline(); + }; + auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); + // Eliminate dead code for each function after inlining. + return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.InlinePrimitives") +.set_body_typed(InlinePrimitives); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 13d8112440fb..a55a9273d078 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) { return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); } +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. + */ struct LambdaLifter : ExprMutator { Module module_; - std::vector> lifted_; explicit LambdaLifter(const Module& module) : module_(module) {} Expr VisitExpr_(const FunctionNode* func_node) final { @@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator { auto free_type_vars = FreeTypeVars(func, module_); auto body = Downcast(ExprMutator::VisitExpr_(func_node)); - // When performing this optimization there are two - // cases. + // When performing this optimization there are two cases. // // The first case in which we have no free variables // we can just lift the function into the global @@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator { // // // The second case requires that we generate a special - // function with makes a distinction between allocating + // function which makes a distinction between allocating // a closure, and then the code for the closure. // // We represent a closure allocation by lifting the @@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator { // function marked as a closure is used to emit allocation // code for the closure's environment. // - // The "inner" function is should be used to generate the + // The "inner" function should be used to generate the // code for the closure. Function lifted_func; if (free_vars.size() == 0) { @@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator { CHECK(lifted_func.defined()); auto name = GenerateName(lifted_func); - auto global = this->module_->GetGlobalVar(name); + auto global = module_->GetGlobalVar(name); - lifted_.push_back({global, lifted_func}); + // Add the lifted function to the module. + module_->Add(global, lifted_func); if (free_vars.size() == 0) { return std::move(global); } else { - // If we need to allocate a closure - // we pass the variables in its environment - // here. + // If we need to allocate a closure, + // we pass the variables in its environment here. Array fvs; for (auto fv : free_vars) { fvs.push_back(fv); @@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator { } } - Function Lift(const Function& func) { - DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl; - return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); + Module Lift() { + // There is an ordering bug here. + auto glob_funcs = module_->functions; + for (auto pair : glob_funcs) { + auto func = pair.second; + DLOG(INFO) << "Lifting " << AsText(func, false); + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(pair.first, func, true); + } + return module_; } }; -/* The goal of this pass is to lift out any nested functions into top-level - * functions. - * - * We will lift the functions out into globals which take the set of the free vars - * and then return a function whcih has b - */ -Module LambdaLift(const Module& module) { - LambdaLifter lifter(module); - - tvm::Map updates; +} // namespace vm - // There is an ordering bug here. - for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, lifter.Lift(func)); - } +namespace transform { - for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) { - module->Add(i->first, i->second); - } +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::LambdaLifter(m).Lift(); + }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.LambdaLift") +.set_body_typed(LambdaLift); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 05eb43d6a653..782bb6a5980f 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -309,20 +309,24 @@ Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing module pass : " + DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; Module updated_mod = mod; - Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. + std::vector > updates; for (const auto& it : mod->functions) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, updated_mod, pass_ctx); - new_mod->Add(it.first, updated_func); + updates.push_back({it.first, updated_func}); + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); } - return new_mod; + return updated_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode @@ -539,7 +543,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) tvm::IRPrinter* p) { p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " << runtime::DeviceName(node->opt_level) + p->stream << "\tfallback device: " + << runtime::DeviceName(node->fallback_device) << "\n"; p->stream << "\trequired passes: [" << node->opt_level;