diff --git a/docs/api/python/relay/backend.rst b/docs/api/python/relay/backend.rst index fa6ab883c20d..c30f226e8437 100644 --- a/docs/api/python/relay/backend.rst +++ b/docs/api/python/relay/backend.rst @@ -28,3 +28,6 @@ tvm.relay.backend .. automodule:: tvm.relay.backend.graph_runtime_codegen :members: + +.. automodule:: tvm.relay.backend.vm + :members: diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 4a407714b414..55763afcf7bc 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -32,7 +32,7 @@ # TODO(moreau89) find a more elegant way to lower for VTAs -def _lower(func, +def _lower(mod, target, params): """ Helper to lower VTA properly. @@ -45,16 +45,16 @@ def _lower(func, with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): import vta with vta.build_config(): - mod, _ = relay.optimize(func, target, params) + mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - return grc.codegen(mod["main"]) + grc.codegen(mod["main"]) # default case - mod, _ = relay.optimize(func, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - return grc.codegen(mod["main"]) + compiler = relay.vm.VMCompiler() + compiler.set_params(params) + compiler.lower(mod, target=target) -def extract_from_program(func, params, ops, target, target_host=None, +def extract_from_program(mod, params, ops, target, target_host=None, template_keys=None): """ Extract tuning tasks from a relay program. @@ -62,8 +62,8 @@ def extract_from_program(func, params, ops, target, target_host=None, Parameters ---------- - func: relay.expr.Function - The func to tune + mod: relay.module.Module or relay.expr.Function + The module or function to tune params: dict of str to numpy array The associated parameters of the program ops: List of relay op @@ -81,11 +81,11 @@ def extract_from_program(func, params, ops, target, target_host=None, task: Array of autotvm.task.Task collected tasks """ - return extract_from_multiple_program([func], [params], ops, target, target_host, - template_keys=template_keys) + return extract_from_multiple_program([mod], [params], ops, target, target_host, + template_keys) -def extract_from_multiple_program(funcs, params, ops, target, target_host=None, +def extract_from_multiple_program(mods, params, ops, target, target_host=None, template_keys=None): """ Extract tuning tasks from multiple relay programs. @@ -94,8 +94,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None, Parameters ---------- - funcs: List of relay.expr.Function - The list of functions to tune + mods: List[relay.module.Module] or List[relay.expr.Function] + The list of modules or functions to tune params: List of dict of str to numpy array The associated parameters of the programs ops: List of relay op @@ -145,10 +145,13 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None, old_state = logger.disabled logger.disabled = True - for func, param in zip(funcs, params): + for mod, param in zip(mods, params): + if isinstance(mod, relay.expr.Function): + mod = relay.Module.from_expr(mod) + assert isinstance(mod, relay.module.Module), \ + "only support relay Module or Function to be tuned" relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems - mod = relay.Module.from_expr(func) build_thread = threading.Thread(target=_lower, args=(mod, target, param)) build_thread.start() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index a523722def61..bad4ac227d09 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -363,7 +363,8 @@ def run(self, *args, **kwargs): def compile(mod, target=None, target_host=None, params=None): - """ + """Compile the module to VM executable. A helper function for VMCompiler. + Parameters ---------- mod : relay.Module @@ -393,26 +394,31 @@ def compile(mod, target=None, target_host=None, params=None): The VM executable that contains both library code and bytecode. """ compiler = VMCompiler() - - target = compiler.update_target(target) - target_host = compiler.update_target_host(target, target_host) if params: compiler.set_params(params) - tophub_context = compiler.tophub_context(target) - with tophub_context: - compiler._compile(mod, target, target_host) - return Executable(compiler._get_exec()) + compiler.lower(mod, target, target_host) + compiler.codegen() + return compiler.get_exec() + class VMCompiler(object): - """Build Relay module to run on VM runtime.""" + """Compiler that compiles Relay module to VM executable.""" def __init__(self): self.mod = _vm._VMCompiler() - self._compile = self.mod["compile"] + self._lower = self.mod["lower"] + self._codegen = self.mod["codegen"] self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] def set_params(self, params): - """Set constant parameters for the model""" + """Set constant parameters for the model. + + Parameters + ---------- + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + """ inputs = {} for name, param in params.items(): if isinstance(param, np.ndarray): @@ -420,8 +426,50 @@ def set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) - def update_target(self, target): - """Update target""" + def lower(self, mod, target=None, target_host=None): + """Lower the module to VM bytecode. + + Parameters + ---------- + mod : relay.Module + The Relay module to build. + + target : str, :any:`tvm.target.Target`, or dict of str(i.e. + device/context name) to str/tvm.target.Target, optional + For heterogeneous compilation, it is a dictionary indicating context + to target mapping. For homogeneous compilation, it is a build target. + + target_host : str or :any:`tvm.target.Target`, optional + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + """ + target = self._update_target(target) + target_host = self._update_target_host(target, target_host) + tophub_context = self._tophub_context(target) + with tophub_context: + self._lower(mod, target, target_host) + + def codegen(self): + """Generate the kernel library.""" + self._codegen() + + def get_exec(self): + """Get the VM executable. + + Returns + ------- + exec : Executable + The VM executable that contains both library code and bytecode. + """ + return Executable(self._get_exec()) + + def _update_target(self, target): + """Update target.""" target = target if target else tvm.target.current_target() if target is None: raise ValueError("Target is not set in env or passed as argument.") @@ -439,8 +487,8 @@ def update_target(self, target): "{}".format(type(target))) return tgts - def update_target_host(self, target, target_host): - """Update target host""" + def _update_target_host(self, target, target_host): + """Update target host.""" target_host = None if target_host == "" else target_host if not target_host: for device_type, tgt in target.items(): @@ -449,9 +497,12 @@ def update_target_host(self, target, target_host): break if not target_host: target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" - return tvm.target.create(target_host) + if isinstance(target_host, str): + target_host = tvm.target.create(target_host) + return target_host - def tophub_context(self, target): + def _tophub_context(self, target): + """Get the autotvm context.""" # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 6a3c580aa56e..7946ee66007c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -743,11 +743,16 @@ class VMFunctionCompiler : ExprFunctor { PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - if (name == "compile") { + if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); Module mod = args[0]; - this->Compile(mod, args[1], args[2]); + this->Lower(mod, args[1], args[2]); + }); + } else if (name == "codegen") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 0); + this->Codegen(); }); } else if (name == "get_executable") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -802,9 +807,9 @@ relay::Function VMCompiler::BindParamsByName( return ret; } -void VMCompiler::Compile(Module mod, - const TargetsMap& targets, - const tvm::Target& target_host) { +void VMCompiler::Lower(Module mod, + const TargetsMap& targets, + const tvm::Target& target_host) { CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { @@ -813,7 +818,7 @@ void VMCompiler::Compile(Module mod, mod->Add(gvar, f); } - InitVM(); + exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -852,11 +857,20 @@ void VMCompiler::Compile(Module mod, exec_->constants.push_back(vm::Tensor(data)); } - LibraryCodegen(); - + // update global function map for (auto gv : context_.global_map) { exec_->global_map.insert({gv.first->name_hint, gv.second}); } + + // update primitive function map + size_t primitive_index = 0; + for (const auto& cfunc : context_.cached_funcs) { + if (cfunc->target->str() == "ext_dev") { + exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + } else { + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + } + } } Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { @@ -942,7 +956,11 @@ void VMCompiler::PopulateGlobalMap() { } } -void VMCompiler::LibraryCodegen() { +void VMCompiler::Codegen() { + if (!context_.module.defined()) { + LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; + return; + } auto const &cached_funcs = context_.cached_funcs; if (cached_funcs.size() == 0) { return; @@ -980,14 +998,6 @@ void VMCompiler::LibraryCodegen() { } } exec_->lib = mod; - size_t primitive_index = 0; - for (auto cfunc : cached_funcs) { - if (cfunc->target->str() == "ext_dev") { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); - } else { - exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); - } - } } runtime::Module CreateVMCompiler() { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 2beab1536a18..7efcb4ba8d81 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -91,10 +91,6 @@ class VMCompiler : public runtime::ModuleNode { return "VMCompiler"; } - void InitVM() { - exec_ = make_object(); - } - /*! * \brief Set the parameters * @@ -104,16 +100,19 @@ class VMCompiler : public runtime::ModuleNode { void SetParam(const std::string& name, runtime::NDArray data_in); /*! - * \brief Compile functions in a Module + * \brief Lower the functions in a Module * * \param mod Relay Module * \param targets For heterogeneous compilation, it is a dictionary indicating context to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ - void Compile(Module mod, - const TargetsMap& targets, - const tvm::Target& target_host); + void Lower(Module mod, + const TargetsMap& targets, + const tvm::Target& target_host); + + /*! \brief Generate the machine code for lowered functions. */ + void Codegen(); protected: /*! @@ -130,8 +129,6 @@ class VMCompiler : public runtime::ModuleNode { void PopulateGlobalMap(); - void LibraryCodegen(); - protected: /*! \brief Target devices. */ TargetsMap targets_; diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py index d29d74322f6f..8f550d82c4f6 100644 --- a/tests/python/relay/test_autotvm_task_extraction.py +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -45,12 +45,20 @@ def test_task_extraction(): params=params, ops=(relay.op.nn.conv2d,)) assert len(tasks) == 12 + tasks = autotvm.task.extract_from_program(mod, target=target, + params=params, + ops=(relay.op.nn.conv2d,)) + assert len(tasks) == 12 mod, params, _ = get_network('resnet-18', batch_size=1) tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params, ops=(relay.op.nn.dense,)) assert len(tasks) == 1 + tasks = autotvm.task.extract_from_program(mod, target=target, + params=params, + ops=(relay.op.nn.dense,)) + assert len(tasks) == 1 mod, params, _ = get_network('resnet-18', batch_size=1) mod_list.append(mod) @@ -59,22 +67,26 @@ def test_task_extraction(): params=params, ops=(relay.op.nn.conv2d, relay.op.nn.dense)) assert len(tasks) == 13 + tasks = autotvm.task.extract_from_program(mod, target=target, + params=params, + ops=(relay.op.nn.conv2d, relay.op.nn.dense)) + assert len(tasks) == 13 mod, params, _ = get_network('mobilenet', batch_size=1) mod_list.append(mod) params_list.append(params) - tasks = autotvm.task.extract_from_program(mod["main"], target=target, + tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(relay.op.nn.conv2d, relay.op.nn.dense)) assert len(tasks) == 20 mod, params, _ = get_network('dcgan', batch_size=1) - tasks = autotvm.task.extract_from_program(mod["main"], target=target, + tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(relay.op.nn.conv2d_transpose,)) assert len(tasks) == 4 - tasks = autotvm.task.extract_from_multiple_program([m['main'] for m in mod_list], params_list, + tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list, target=target, ops=(relay.op.nn.conv2d,)) assert len(tasks) == 31 diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a4c5b7d2a3c3..8a160b11ee65 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -572,4 +572,4 @@ def test_add_op_broadcast(): if __name__ == "__main__": - pytest.main() + pytest.main([__file__])