Skip to content

Commit

Permalink
Merge branch 'vm_te_migration' of github.com:mikepapadim/tvm into vm_…
Browse files Browse the repository at this point in the history
…te_migration
  • Loading branch information
mikepapadim committed Jul 20, 2021
2 parents 2d1847c + a33e069 commit 9fd6552
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
19 changes: 10 additions & 9 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@
#include <vector>

#include "../../../target/source/codegen_source_base.h"
#include "../../backend/compile_engine.h"
#include "../../op/op_common.h"
#include "../../transforms/pass_utils.h"
#include "../te_compiler_cache.h"
#include "../utils.h"
#include "compiler.h"
#include "../te_compiler.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -79,6 +80,7 @@ namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;
using namespace tec;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Expand Down Expand Up @@ -253,7 +255,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
ExprDeviceMap expr_device_map)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
target_host_(target_host),
expr_device_map_(std::move(expr_device_map)) {
Expand Down Expand Up @@ -465,7 +466,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
// Lower shape function
CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
auto cfunc = compiler_->LowerShapeFunc(key);
int op_index = -1;
// pick the only function inside the context
ICHECK_EQ(cfunc->funcs->functions.size(), 1);
Expand Down Expand Up @@ -551,7 +552,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

CCacheKey key(func, target);
auto mangle_fn = [](String name) { return name; };
auto cfunc = engine_->Lower(key, mangle_fn);
auto cfunc = compiler_->Lower(key, mangle_fn);

auto op_index = -1;
if (func->GetAttr<String>(attr::kCompiler).defined()) {
Expand Down Expand Up @@ -858,7 +859,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
/*! \brief Total number of virtual registers allocated. */
size_t registers_num_;
/*! \brief Compiler engine to lower primitive functions. */
CompileEngine engine_;
TECompiler compiler_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
/*! \brief Target devices. */
Expand Down Expand Up @@ -1184,8 +1185,9 @@ void VMCompiler::Codegen() {
}
}

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
TECompiler compiler;
auto ext_mods = compiler->LowerExternalFunctions();

runtime::Module lib;
if (funcs.size() > 0) {
lib = tvm::build(funcs, target_host_);
Expand All @@ -1196,8 +1198,7 @@ void VMCompiler::Codegen() {
}
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata());
exec_->SetLib(lib);
CompileEngine::Global()->Clear();
}
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Device default_device;
Expand Down
6 changes: 4 additions & 2 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@

#include "../../../runtime/vm/naive_allocator.h"
#include "../../../runtime/vm/profiler/vm.h"
#include "../../backend/compile_engine.h"
#include "../../transforms/pass_utils.h"
#include "../te_compiler_cache.h"
#include "../te_compiler.h"


namespace tvm {
namespace relay {
Expand Down Expand Up @@ -80,7 +82,7 @@ struct VMCompilerContext {
// Device type for constants
std::vector<Index> const_device_type;
// List of cached functions
std::vector<CachedFunc> cached_funcs;
std::vector<tec::CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<tir::PrimFunc, size_t, ObjectPtrHash, ObjectPtrEqual> seen_funcs;
};
Expand Down
11 changes: 8 additions & 3 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@
#include <unordered_set>
#include <vector>

#include "../backend/compile_engine.h"
#include "../op/memory/memory.h"
#include "../op/vm/vm.h"
#include "./pass_utils.h"
#include "let_list.h"
#include "pattern_utils.h"

#include "../backend/te_compiler_cache.h"
#include "../backend/te_compiler.h"

using namespace tvm::runtime;
using namespace tvm::relay::tec;

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -271,9 +274,11 @@ class DialectRewriter : public ExprMutator {
Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
const std::vector<Expr>& new_args) {
Array<Expr> shape_func_ins;
auto engine = CompileEngine::Global();

TECompiler compiler;

CCacheKey key(func, target_host_);
auto cfunc = engine->LowerShapeFunc(key);
auto cfunc = compiler->LowerShapeFunc(key);
auto input_states = cfunc->shape_func_param_states;

Array<Integer> is_inputs;
Expand Down

0 comments on commit 9fd6552

Please sign in to comment.