Skip to content

Commit

Permalink
Move TEcompiler to VMCompilerContext; add global func into IRmodule w…
Browse files Browse the repository at this point in the history
…hen lowering in TEcompiler
  • Loading branch information
YuchenJin committed Jul 22, 2021
1 parent aed5d3b commit a30cf6a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode {
auto target = Target("ext_dev");
auto global_var = GlobalVar(func_name);
global_var->checked_type_ = key->source_func->checked_type();
ir_module->Add(global_var, key->source_func);
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
return value;
}
Expand Down
11 changes: 3 additions & 8 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
#include "../../../target/source/codegen_source_base.h"
#include "../../op/op_common.h"
#include "../../transforms/pass_utils.h"
#include "../te_compiler.h"
#include "../te_compiler_cache.h"
#include "../utils.h"
#include "compiler.h"

Expand Down Expand Up @@ -466,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
// Lower shape function
CCacheKey key(func, target_host_);
auto cfunc = compiler_->LowerShapeFunc(key);
auto cfunc = context_->compiler->LowerShapeFunc(key);
int op_index = -1;
// pick the only function inside the context
ICHECK_EQ(cfunc->funcs->functions.size(), 1);
Expand Down Expand Up @@ -552,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

CCacheKey key(func, target);
auto mangle_fn = [](String name) { return name; };
auto cfunc = compiler_->Lower(key, mangle_fn);
auto cfunc = context_->compiler->Lower(key, mangle_fn);

auto op_index = -1;
if (func->GetAttr<String>(attr::kCompiler).defined()) {
Expand Down Expand Up @@ -858,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
size_t last_register_;
/*! \brief Total number of virtual registers allocated. */
size_t registers_num_;
/*! \brief Compiler engine to lower primitive functions. */
TECompiler compiler_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
/*! \brief Target devices. */
Expand Down Expand Up @@ -1185,8 +1181,7 @@ void VMCompiler::Codegen() {
}
}

TECompiler compiler;
auto ext_mods = compiler->LowerExternalFunctions();
auto ext_mods = context_.compiler->LowerExternalFunctions();

runtime::Module lib;
if (funcs.size() > 0) {
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct VMCompilerContext {
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
// TEcompiler for lowering
tec::TECompiler compiler;
// List of constants
std::vector<NDArray> constants;
// Device type for constants
Expand Down

0 comments on commit a30cf6a

Please sign in to comment.