Skip to content

Commit

Permalink
[Relay][VM] Add more passes to VMCompiler (#4058)
Browse files Browse the repository at this point in the history
* [Relay][VM] Add more passes to VMCompiler

* Check build config

* Add todo
  • Loading branch information
wweic authored and zhiics committed Oct 5, 2019
1 parent 7084081 commit 92ffa06
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 13 deletions.
70 changes: 58 additions & 12 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
Expand Down Expand Up @@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref,

// Run some optimizations first, this code should
// be moved to pass manager.
context_.module = OptimizeModule(mod_ref);
context_.module = OptimizeModule(mod_ref, targets_);

// Populate the global map.
//
Expand Down Expand Up @@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref,
}
}

Module VMCompiler::OptimizeModule(const Module& mod) {
// TODO(@icemelon9): check number of targets and build config, add more optimization pass
transform::Sequential seq({transform::SimplifyInference(),
transform::InlinePrimitives(),
// TODO(@wweic): FuseOps pass currently don't handle Let
// For now, we put FuseOps before ToANormalForm to enable it
transform::FuseOps(),
transform::ToANormalForm(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());

// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}

pass_seqs.push_back(transform::FoldConstant());

pass_seqs.push_back(transform::FuseOps());
pass_seqs.push_back(transform::ToANormalForm());
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

transform::Sequential seq(pass_seqs);
transform::PassContext pass_ctx = PassContext::Current();
// TODO(wweic): Support heterogenous execution
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
if (targets.size() == 1) {
for (const auto& kv : targets) {
With<Target> tctx(kv.second);
return seq(mod);
}
}
return seq(mod);
}

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode {
const tvm::Target& target_host);

protected:
Module OptimizeModule(const Module& mod);
Module OptimizeModule(const Module& mod, const TargetsMap& targets);

void PopulateGlobalMap();

Expand Down

0 comments on commit 92ffa06

Please sign in to comment.