Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Autotvm] Use VM compile to extract autotvm tasks #4328

Merged
merged 8 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api/python/relay/backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ tvm.relay.backend

.. automodule:: tvm.relay.backend.graph_runtime_codegen
:members:

.. automodule:: tvm.relay.backend.vm
:members:
35 changes: 19 additions & 16 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


# TODO(moreau89) find a more elegant way to lower for VTAs
def _lower(func,
def _lower(mod,
target,
params):
""" Helper to lower VTA properly.
Expand All @@ -45,25 +45,25 @@ def _lower(func,
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
import vta
with vta.build_config():
mod, _ = relay.optimize(func, target, params)
mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
return grc.codegen(mod["main"])
grc.codegen(mod["main"])
# default case
mod, _ = relay.optimize(func, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
return grc.codegen(mod["main"])
compiler = relay.vm.VMCompiler()
compiler.set_params(params)
compiler.lower(mod, target=target)


def extract_from_program(func, params, ops, target, target_host=None,
def extract_from_program(mod, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from a relay program.

This function is the single program version of extract_from_multiple_program.

Parameters
----------
func: relay.expr.Function
The func to tune
mod: relay.module.Module or relay.expr.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
ops: List of relay op
Expand All @@ -81,11 +81,11 @@ def extract_from_program(func, params, ops, target, target_host=None,
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([func], [params], ops, target, target_host,
template_keys=template_keys)
return extract_from_multiple_program([mod], [params], ops, target, target_host,
template_keys)


def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
def extract_from_multiple_program(mods, params, ops, target, target_host=None,
template_keys=None):
""" Extract tuning tasks from multiple relay programs.

Expand All @@ -94,8 +94,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,

Parameters
----------
funcs: List of relay.expr.Function
The list of functions to tune
mods: List[relay.module.Module] or List[relay.expr.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
ops: List of relay op
Expand Down Expand Up @@ -145,10 +145,13 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
old_state = logger.disabled
logger.disabled = True

for func, param in zip(funcs, params):
for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod)
assert isinstance(mod, relay.module.Module), \
"only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
build_thread = threading.Thread(target=_lower,
args=(mod, target, param))
build_thread.start()
Expand Down
85 changes: 68 additions & 17 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ def run(self, *args, **kwargs):


def compile(mod, target=None, target_host=None, params=None):
"""
"""Compile the module to VM executable. A helper function for VMCompiler.

Parameters
----------
mod : relay.Module
Expand Down Expand Up @@ -393,35 +394,82 @@ def compile(mod, target=None, target_host=None, params=None):
The VM executable that contains both library code and bytecode.
"""
compiler = VMCompiler()

target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return Executable(compiler._get_exec())
compiler.lower(mod, target, target_host)
compiler.codegen()
return compiler.get_exec()


class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
"""Compiler that compiles Relay module to VM executable."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._lower = self.mod["lower"]
self._codegen = self.mod["codegen"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]

def set_params(self, params):
"""Set constant parameters for the model"""
"""Set constant parameters for the model.

Parameters
----------
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
"""
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)

def update_target(self, target):
"""Update target"""
def lower(self, mod, target=None, target_host=None):
"""Lower the module to VM bytecode.

Parameters
----------
mod : relay.Module
The Relay module to build.

target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.

target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
"""
target = self._update_target(target)
target_host = self._update_target_host(target, target_host)
tophub_context = self._tophub_context(target)
with tophub_context:
self._lower(mod, target, target_host)

def codegen(self):
"""Generate the kernel library."""
self._codegen()

def get_exec(self):
"""Get the VM executable.

Returns
-------
exec : Executable
The VM executable that contains both library code and bytecode.
"""
return Executable(self._get_exec())

def _update_target(self, target):
"""Update target."""
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
Expand All @@ -439,8 +487,8 @@ def update_target(self, target):
"{}".format(type(target)))
return tgts

def update_target_host(self, target, target_host):
"""Update target host"""
def _update_target_host(self, target, target_host):
"""Update target host."""
target_host = None if target_host == "" else target_host
if not target_host:
for device_type, tgt in target.items():
Expand All @@ -449,9 +497,12 @@ def update_target_host(self, target, target_host):
break
if not target_host:
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
return tvm.target.create(target_host)
if isinstance(target_host, str):
target_host = tvm.target.create(target_host)
return target_host

def tophub_context(self, target):
def _tophub_context(self, target):
"""Get the autotvm context."""
# If current dispatch context is fallback context (the default root context),
# then load pre-tuned parameters from TopHub
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
Expand Down
44 changes: 27 additions & 17 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,16 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

PackedFunc VMCompiler::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
if (name == "compile") {
if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
this->Lower(mod, args[1], args[2]);
});
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 0);
this->Codegen();
});
} else if (name == "get_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -802,9 +807,9 @@ relay::Function VMCompiler::BindParamsByName(
return ret;
}

void VMCompiler::Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
void VMCompiler::Lower(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
Expand All @@ -813,7 +818,7 @@ void VMCompiler::Compile(Module mod,
mod->Add(gvar, f);
}

InitVM();
exec_ = make_object<Executable>();
targets_ = targets;
target_host_ = target_host;

Expand Down Expand Up @@ -852,11 +857,20 @@ void VMCompiler::Compile(Module mod,
exec_->constants.push_back(vm::Tensor(data));
}

LibraryCodegen();

// update global function map
for (auto gv : context_.global_map) {
exec_->global_map.insert({gv.first->name_hint, gv.second});
}

// update primitive function map
size_t primitive_index = 0;
for (const auto& cfunc : context_.cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}

Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
Expand Down Expand Up @@ -942,7 +956,11 @@ void VMCompiler::PopulateGlobalMap() {
}
}

void VMCompiler::LibraryCodegen() {
void VMCompiler::Codegen() {
if (!context_.module.defined()) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return;
}
auto const &cached_funcs = context_.cached_funcs;
if (cached_funcs.size() == 0) {
return;
Expand Down Expand Up @@ -980,14 +998,6 @@ void VMCompiler::LibraryCodegen() {
}
}
exec_->lib = mod;
size_t primitive_index = 0;
for (auto cfunc : cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}

runtime::Module CreateVMCompiler() {
Expand Down
17 changes: 7 additions & 10 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ class VMCompiler : public runtime::ModuleNode {
return "VMCompiler";
}

void InitVM() {
exec_ = make_object<Executable>();
}

/*!
* \brief Set the parameters
*
Expand All @@ -104,16 +100,19 @@ class VMCompiler : public runtime::ModuleNode {
void SetParam(const std::string& name, runtime::NDArray data_in);

/*!
* \brief Compile functions in a Module
* \brief Lower the functions in a Module
*
* \param mod Relay Module
* \param targets For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
* \param target_host Host compilation target, if target is device.
*/
void Compile(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);
void Lower(Module mod,
const TargetsMap& targets,
const tvm::Target& target_host);

/*! \brief Generate the machine code for lowered functions. */
void Codegen();

protected:
/*!
Expand All @@ -130,8 +129,6 @@ class VMCompiler : public runtime::ModuleNode {

void PopulateGlobalMap();

void LibraryCodegen();

protected:
/*! \brief Target devices. */
TargetsMap targets_;
Expand Down
Loading