From 08cc0729b4086e2c41ec360085c836b55d3a4111 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 4 Oct 2019 14:50:27 -0700 Subject: [PATCH 1/3] [Relay][VM] Add more passes to VMCompiler --- src/relay/backend/vm/compiler.cc | 68 ++++++++++++++++++++++++++------ src/relay/backend/vm/compiler.h | 2 +- 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8c60fe6d0ee7..c6a10dc3cb7b 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref, // Run some optimizations first, this code should // be moved to pass manager. - context_.module = OptimizeModule(mod_ref); + context_.module = OptimizeModule(mod_ref, targets_); // Populate the global map. // @@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref, } } -Module VMCompiler::OptimizeModule(const Module& mod) { - // TODO(@icemelon9): check number of targets and build config, add more optimization pass - transform::Sequential seq({transform::SimplifyInference(), - transform::InlinePrimitives(), - // TODO(@wweic): FuseOps pass currently don't handle Let - // For now, we put FuseOps before ToANormalForm to enable it - transform::FuseOps(), - transform::ToANormalForm(), - transform::LambdaLift(), - transform::InlinePrimitives()}); +Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { + // TODO(@icemelon9): check number of targets and build config + Array pass_seqs; + // Run all dialect legalization passes. + pass_seqs.push_back(relay::qnn::transform::Legalize()); + + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + + pass_seqs.push_back(transform::SimplifyInference()); + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == Int(32)) { + *rv = true; + } + } + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::InlinePrimitives()); + + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + pass_seqs.push_back(transform::AlterOpLayout()); + } + + pass_seqs.push_back(transform::FoldConstant()); + + pass_seqs.push_back(transform::FuseOps()); + pass_seqs.push_back(transform::ToANormalForm()); + pass_seqs.push_back(transform::LambdaLift()); + pass_seqs.push_back(transform::InlinePrimitives()); + + transform::Sequential seq(pass_seqs); auto pass_ctx = transform::PassContext::Create(); tvm::With ctx(pass_ctx); + if (targets.size() == 1) { + for (const auto& kv : targets) { + With tctx(kv.second); + return seq(mod); + } + } return seq(mod); } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index bfe19ac2140e..14a5035b20dc 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode { const tvm::Target& target_host); protected: - Module OptimizeModule(const Module& mod); + Module OptimizeModule(const Module& mod, const TargetsMap& targets); void PopulateGlobalMap(); From ff4e4b535a11ec04c8b7ff754fe621bac16802a0 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 4 Oct 2019 16:24:58 -0700 Subject: [PATCH 2/3] Check build config --- src/relay/backend/vm/compiler.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c6a10dc3cb7b..e1198456583e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -846,7 +846,6 @@ void VMCompiler::Compile(const Module& mod_ref, } Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { - // TODO(@icemelon9): check number of targets and build config Array pass_seqs; // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); @@ -894,7 +893,7 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) pass_seqs.push_back(transform::InlinePrimitives()); transform::Sequential seq(pass_seqs); - auto pass_ctx = transform::PassContext::Create(); + transform::PassContext pass_ctx = PassContext::Current(); tvm::With ctx(pass_ctx); if (targets.size() == 1) { for (const auto& kv : targets) { From be95ddbf917b80c02961348a115213eac29dc588 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 4 Oct 2019 22:48:59 -0700 Subject: [PATCH 3/3] Add todo --- src/relay/backend/vm/compiler.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e1198456583e..49079fbc107e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -894,6 +894,7 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); + // TODO(wweic): Support heterogenous execution tvm::With ctx(pass_ctx); if (targets.size() == 1) { for (const auto& kv : targets) {