restructure vm compiler to reduce task extraction time
icemelon committed Jan 8, 2020
Parent: f5499b1 · Commit: ed1dd3d
Showing 5 changed files with 67 additions and 43 deletions.
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/relay_integration.py
@@ -49,7 +49,9 @@ def _lower(mod,
                 grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
                 return grc.codegen(mod["main"])
     # default case
-    return relay.vm.compile(mod, target=target, params=params)
+    compiler = relay.vm.VMCompiler()
+    compiler.set_params(params)
+    compiler.lower(mod, target=target)
 
 
 def extract_from_program(mod, params, ops, target, target_host=None,
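For context, autotvm task extraction only needs the lowering phase: tasks are recorded while operators are lowered, so the old call to relay.vm.compile also paid for machine-code generation that extraction never used. Below is a minimal sketch of the same lower-only pattern used by _lower above; the toy workload, empty params dict, and llvm target are assumptions for illustration only.

from tvm import relay

# Hypothetical workload; any relay.Module plus a params dict would do here.
x = relay.var("x", shape=(1, 8), dtype="float32")
mod = relay.Module.from_expr(relay.Function([x], x + x))
params = {}

compiler = relay.vm.VMCompiler()
if params:
    compiler.set_params(params)
# Lowering is enough for task extraction; codegen() is deliberately skipped.
compiler.lower(mod, target="llvm")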
43 changes: 29 additions & 14 deletions python/tvm/relay/backend/vm.py
@@ -363,7 +363,8 @@ def run(self, *args, **kwargs):
 
 
 def compile(mod, target=None, target_host=None, params=None):
-    """
+    """Compile the module to VM executable.
+
     Parameters
     ----------
     mod : relay.Module
@@ -393,21 +394,19 @@ def compile(mod, target=None, target_host=None, params=None):
         The VM executable that contains both library code and bytecode.
     """
     compiler = VMCompiler()
-
-    target = compiler.update_target(target)
-    target_host = compiler.update_target_host(target, target_host)
     if params:
         compiler.set_params(params)
-    tophub_context = compiler.tophub_context(target)
-    with tophub_context:
-        compiler._compile(mod, target, target_host)
-    return Executable(compiler._get_exec())
+    compiler.lower(mod, target, target_host)
+    compiler.codegen()
+    return compiler.get_exec()
 
 
 class VMCompiler(object):
     """Build Relay module to run on VM runtime."""
     def __init__(self):
         self.mod = _vm._VMCompiler()
-        self._compile = self.mod["compile"]
+        self._lower = self.mod["lower"]
+        self._codegen = self.mod["codegen"]
         self._get_exec = self.mod["get_executable"]
         self._set_params_func = self.mod["set_params"]
 
@@ -420,8 +419,24 @@ def set_params(self, params):
             inputs[name] = _expr.const(param)
         self._set_params_func(inputs)
 
-    def update_target(self, target):
-        """Update target"""
+    def lower(self, mod, target=None, target_host=None):
+        """Lower the module to VM bytecode."""
+        target = self._update_target(target)
+        target_host = self._update_target_host(target, target_host)
+        tophub_context = self._tophub_context(target)
+        with tophub_context:
+            self._lower(mod, target, target_host)
+
+    def codegen(self):
+        """Generate the machine code."""
+        self._codegen()
+
+    def get_exec(self):
+        """Return the executable."""
+        return Executable(self._get_exec())
+
+    def _update_target(self, target):
+        """Update target."""
         target = target if target else tvm.target.current_target()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
@@ -439,8 +454,8 @@ def update_target(self, target):
                             "{}".format(type(target)))
         return tgts
 
-    def update_target_host(self, target, target_host):
-        """Update target host"""
+    def _update_target_host(self, target, target_host):
+        """Update target host."""
         target_host = None if target_host == "" else target_host
         if not target_host:
             for device_type, tgt in target.items():
@@ -453,7 +468,7 @@ def update_target_host(self, target, target_host):
             target_host = tvm.target.create(target_host)
         return target_host
 
-    def tophub_context(self, target):
+    def _tophub_context(self, target):
         # If current dispatch context is fallback context (the default root context),
         # then load pre-tuned parameters from TopHub
         if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
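With the split above, compile() is now a thin wrapper over an explicit three-step pipeline. A sketch of driving VMCompiler by hand, which should produce the same Executable as the one-shot call; the toy module and llvm target are assumptions for illustration.

from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
mod = relay.Module.from_expr(relay.Function([x], x + x))

# One-shot path: lower + codegen + executable packaging in a single call.
exe = relay.vm.compile(mod, target="llvm")

# Equivalent explicit path, useful when only part of the pipeline is needed.
compiler = relay.vm.VMCompiler()
compiler.lower(mod, target="llvm")   # bytecode, constants, lowered functions
compiler.codegen()                   # machine code for the lowered functions
exe2 = compiler.get_exec()           # an Executable, like exe above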
46 changes: 28 additions & 18 deletions src/relay/backend/vm/compiler.cc
@@ -743,11 +743,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
 PackedFunc VMCompiler::GetFunction(const std::string& name,
                                    const ObjectPtr<Object>& sptr_to_self) {
-  if (name == "compile") {
+  if (name == "lower") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args.num_args, 3);
       Module mod = args[0];
-      this->Compile(mod, args[1], args[2]);
+      this->Lower(mod, args[1], args[2]);
     });
+  } else if (name == "codegen") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.num_args, 0);
+      this->Codegen();
+    });
   } else if (name == "get_executable") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -802,9 +807,9 @@ relay::Function VMCompiler::BindParamsByName(
   return ret;
 }
 
-void VMCompiler::Compile(Module mod,
-                         const TargetsMap& targets,
-                         const tvm::Target& target_host) {
+void VMCompiler::Lower(Module mod,
+                       const TargetsMap& targets,
+                       const tvm::Target& target_host) {
   CHECK_EQ(targets.size(), 1)
     << "Currently VM compiler doesn't support heterogeneous compilation";
   if (params_.size()) {
@@ -813,7 +818,7 @@ void VMCompiler::Compile(Module mod,
     mod->Add(gvar, f);
   }
 
-  InitVM();
+  exec_ = make_object<Executable>();
   targets_ = targets;
   target_host_ = target_host;
 
@@ -852,11 +857,20 @@
     exec_->constants.push_back(vm::Tensor(data));
   }
 
-  LibraryCodegen();
-
   // update global function map
   for (auto gv : context_.global_map) {
     exec_->global_map.insert({gv.first->name_hint, gv.second});
   }
+
+  // update primitive function map
+  size_t primitive_index = 0;
+  for (const auto& cfunc : context_.cached_funcs) {
+    if (cfunc->target->str() == "ext_dev") {
+      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
+    } else {
+      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    }
+  }
 }
 
 Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
@@ -942,7 +956,11 @@ void VMCompiler::PopulateGlobalMap() {
   }
 }
 
-void VMCompiler::LibraryCodegen() {
+void VMCompiler::Codegen() {
+  if (!context_.module.defined()) {
+    LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
+    return;
+  }
   auto const &cached_funcs = context_.cached_funcs;
   if (cached_funcs.size() == 0) {
     return;
@@ -979,15 +997,7 @@
       }
     }
   }
-  exec_->lib = mod;
-  size_t primitive_index = 0;
-  for (auto cfunc : cached_funcs) {
-    if (cfunc->target->str() == "ext_dev") {
-      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
-    } else {
-      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
-    }
-  }
+  exec_->lib = mod;
 }
 
 runtime::Module CreateVMCompiler() {
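On the Python side, the VMCompiler wrapper in vm.py binds to exactly the packed functions registered here ("lower", "codegen", "get_executable", "set_params"). A small sketch of that wiring, including the guard added to Codegen(); the direct .mod lookups are shown only for illustration and mirror what the wrapper methods already do.

from tvm import relay

compiler = relay.vm.VMCompiler()

# Packed functions registered by VMCompiler::GetFunction.
lower_pf = compiler.mod["lower"]              # dispatches to VMCompiler::Lower
codegen_pf = compiler.mod["codegen"]          # dispatches to VMCompiler::Codegen
get_exec_pf = compiler.mod["get_executable"]

# Calling codegen before lower hits the new guard in VMCompiler::Codegen:
# it logs "Did you forget to call VMCompiler::Lower?" and returns.
compiler.codegen()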
15 changes: 6 additions & 9 deletions src/relay/backend/vm/compiler.h
@@ -91,10 +91,6 @@ class VMCompiler : public runtime::ModuleNode {
     return "VMCompiler";
   }
 
-  void InitVM() {
-    exec_ = make_object<Executable>();
-  }
-
   /*!
    * \brief Set the parameters
    *
@@ -111,9 +107,12 @@
      to target mapping. For homogeneous compilation, it is a build target.
    * \param target_host Host compilation target, if target is device.
    */
-  void Compile(Module mod,
-               const TargetsMap& targets,
-               const tvm::Target& target_host);
+  void Lower(Module mod,
+             const TargetsMap& targets,
+             const tvm::Target& target_host);
+
+  /*! \brief Generate the machine code for lowered functions. */
+  void Codegen();
 
  protected:
   /*!
@@ -130,8 +129,6 @@
 
   void PopulateGlobalMap();
 
-  void LibraryCodegen();
-
  protected:
   /*! \brief Target devices. */
   TargetsMap targets_;
2 changes: 1 addition & 1 deletion tests/python/relay/test_vm.py
@@ -572,4 +572,4 @@ def test_add_op_broadcast():
 
 
 if __name__ == "__main__":
-    pytest.main()
+    pytest.main([__file__])
