diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 8c60fe6d0ee7..49079fbc107e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref,
 
   // Run some optimizations first, this code should
   // be moved to pass manager.
-  context_.module = OptimizeModule(mod_ref);
+  context_.module = OptimizeModule(mod_ref, targets_);
 
   // Populate the global map.
   //
@@ -844,18 +845,67 @@ void VMCompiler::Compile(const Module& mod_ref,
   }
 }
 
-Module VMCompiler::OptimizeModule(const Module& mod) {
-  // TODO(@icemelon9): check number of targets and build config, add more optimization pass
-  transform::Sequential seq({transform::SimplifyInference(),
-                             transform::InlinePrimitives(),
-                             // TODO(@wweic): FuseOps pass currently don't handle Let
-                             // For now, we put FuseOps before ToANormalForm to enable it
-                             transform::FuseOps(),
-                             transform::ToANormalForm(),
-                             transform::LambdaLift(),
-                             transform::InlinePrimitives()});
-  auto pass_ctx = transform::PassContext::Create();
+// Builds the VM optimization pipeline and runs it over `mod`. Legalize and
+// AlterOpLayout are only scheduled when exactly one target is given.
+Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
+  Array<transform::Pass> pass_seqs;
+  // Run all dialect legalization passes.
+  pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+  // Legalize pass is restricted to homogeneous execution for now.
+  if (targets.size() == 1) {
+    pass_seqs.push_back(transform::Legalize());
+  }
+
+  pass_seqs.push_back(transform::SimplifyInference());
+  // fskip reports true for casts to int32 and false for everything else;
+  // presumably EliminateCommonSubexpr leaves the flagged calls alone.
+  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    Expr expr = args[0];
+    bool skip = false;
+    if (const auto* call_node = expr.as<CallNode>()) {
+      // The callee need not be an operator, so guard the OpNode downcast.
+      const auto* op_node = call_node->op.as<OpNode>();
+      if (op_node != nullptr && op_node->name == "cast") {
+        const auto* attrs = call_node->attrs.as<CastAttrs>();
+        skip = attrs != nullptr && attrs->dtype == Int(32);
+      }
+    }
+    // Write the result exactly once so a match is never overwritten.
+    *rv = skip;
+  });
+  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+  pass_seqs.push_back(transform::InlinePrimitives());
+
+  pass_seqs.push_back(transform::CombineParallelConv2D(3));
+  pass_seqs.push_back(transform::CombineParallelDense(3));
+  pass_seqs.push_back(transform::FoldConstant());
+  pass_seqs.push_back(transform::FoldScaleAxis());
+  pass_seqs.push_back(transform::CanonicalizeCast());
+  pass_seqs.push_back(transform::CanonicalizeOps());
+
+  // Layout alteration is only applied to homogeneous execution for now.
+  if (targets.size() == 1) {
+    pass_seqs.push_back(transform::AlterOpLayout());
+  }
+
+  pass_seqs.push_back(transform::FoldConstant());
+
+  pass_seqs.push_back(transform::FuseOps());
+  pass_seqs.push_back(transform::ToANormalForm());
+  pass_seqs.push_back(transform::LambdaLift());
+  pass_seqs.push_back(transform::InlinePrimitives());
+
+  transform::Sequential seq(pass_seqs);
+  transform::PassContext pass_ctx = PassContext::Current();
+  // TODO(wweic): Support heterogeneous execution
   tvm::With<relay::transform::PassContext> ctx(pass_ctx);
+  if (targets.size() == 1) {
+    // Keep the sole target's With-scope alive while the pipeline runs.
+    const auto& kv = *targets.begin();
+    With<Target> tctx(kv.second);
+    return seq(mod);
+  }
   return seq(mod);
 }
 
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index bfe19ac2140e..14a5035b20dc 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode {
                const tvm::Target& target_host);
 
  protected:
-  Module OptimizeModule(const Module& mod);
+  Module OptimizeModule(const Module& mod, const TargetsMap& targets);
 
   void PopulateGlobalMap();
 