Skip to content

Commit

Permalink
[REFACTOR][IR] Allow Module to store BaseFunc.
Browse files Browse the repository at this point in the history
Under the unified IR. We will allow a single IRModule
to store different function variants, such as relay::Function,
ExternFunc, and low-level function.

This PR changes relay::Function -> BaseFunc in the module file
to support multiple function variants.
  • Loading branch information
tqchen committed Jan 11, 2020
1 parent 12e51e6 commit fc4c42e
Show file tree
Hide file tree
Showing 15 changed files with 206 additions and 126 deletions.
16 changes: 8 additions & 8 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct Module;
class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;

Expand All @@ -75,7 +75,7 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});

Expand All @@ -86,7 +86,7 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);

/*!
* \brief Add a function to the global environment.
Expand All @@ -95,7 +95,7 @@ class ModuleNode : public RelayNode {
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);

/*!
* \brief Add a type-level definition to the global environment.
Expand Down Expand Up @@ -124,7 +124,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global function to update.
* \param func The new function.
*/
TVM_DLL void Update(const GlobalVar& var, const Function& func);
TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);

/*!
* \brief Update a type definition in the global environment.
Expand Down Expand Up @@ -184,14 +184,14 @@ class ModuleNode : public RelayNode {
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;
TVM_DLL BaseFunc Lookup(const std::string& name) const;

/*!
* \brief Look up a global type definition by its variable.
Expand Down Expand Up @@ -256,7 +256,7 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});

static constexpr const char* _type_key = "relay.Module";
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode {
// Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function.
func = relay_module->Lookup("main");
func = Downcast<Function>(relay_module->Lookup("main"));

// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
Expand Down
29 changes: 20 additions & 9 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
auto func = context_->module->Lookup(global);

// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
// perhaps establish as an invariance(all functions in mod must be relay::Function)
auto func = Downcast<Function>(context_->module->Lookup(global));


if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
Expand Down Expand Up @@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod,
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
Expand All @@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod,

for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);

size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
if (auto* n = named_func.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);

size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
}
}

#if USE_RELAY_DEBUG
Expand Down
30 changes: 17 additions & 13 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
}
return module_;
}
Expand Down
16 changes: 9 additions & 7 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
if (auto* n = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
}
return module_;
}
Expand Down
Loading

0 comments on commit fc4c42e

Please sign in to comment.