diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index aa8543d569af..a276c658c496 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -430,15 +431,184 @@ struct VMFrame { caller_return_register(0) {} }; +/*! \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + * + * - Global section, containing all globals. + * - Constant section, storing the constant pool. + * - Primitive name section, containing the function name of the primitive ops + * used by the virtual machine. + * - Code section, handling the VM functions and bytecode. + */ +class Executable : public ModuleNode { + public: + /*! + * \brief Get a PackedFunc from an executable module. + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + /*! + * \brief Serialize the executable into global section, constant section, and + * code section. + * + * \return The binary representation of the VM. + */ + TVMByteArray Save(); + + /*! + * \brief Load the saved VM executable. + * + * \param code The bytecode in string. + * \param lib The compiled runtime library. + * + * \return exe The constructed executable. + */ + static runtime::Module Load(const std::string& code, const runtime::Module lib); + + /*! + * \brief Get the serialized form of the `functions`. This is + * essentially bytecode serialization. + * + * \return The serialized vm bytecode. + * + * \note The bytecode is in the following format: + * func_name reg_file_size num_instructions + * param1 param2 ... paramM + * instruction1 + * instruction2 + * ... + * instructionN + * + * Each instruction is printed in the following format: + * opcode num_fields field1 ... fieldX # The text format. + * + * Serializing an `Instruction` requires us to deal with the bytecode. Each line + * of the instructions could be serialized as the following format: + * hash, opcode, f1, f2, ..., fX, field with variable length + * 1. hash: the hash of the instruction. This number will be used to help us + * validate if an instruction is well-formed during deserialization. + * 2. opcode: the opcode code of the instruction. + * 3. f1, f2, ..., fX. These fields together represent the fixed fields in + * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For + * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). + * 4. The rest of the line indicates the field with variable length, e.g., + * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. + + * The field starting from # is only used for debugging. The serialized code + * doesn't contain it, therefore the deserializer doens't need to handle it. + */ + std::string GetBytecode() const; + +/*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globls and constants, etc. + */ + std::string Stats() const; + + /*! \brief Get the `lib` module in an executable. Users have the flexibility to call + * `export_library` from the frontend to save the library to disk. + * + * \return The runtime module that contains the hardwre dependent code. + */ + runtime::Module GetLib() const { return lib; } + + virtual ~Executable() {} + + const char* type_key() const final { + return "VMExecutable"; + } + + /*! \brief The runtime module/library that contains both the host and also the device + * code when executing on non-CPU devices. */ + runtime::Module lib; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map global_map; + /*! \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. + */ + std::unordered_map primitive_map; + /*! \brief The virtual machine's function table. */ + std::vector functions; + + private: + /*! + * \brief Save the globals. + * + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + + /*! + * \brief Save the constant pool. + * + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + + /*! + * \brief Save primitive op names. + * + * \param strm The input stream. + */ + void SavePrimitiveOpNames(dmlc::Stream* strm); + + /*! + * \brief Save the vm functions. + * + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); + + /*! + * \brief Load the globals. + * + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + + /*! + * \brief Load the constant pool. + * + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + + /*! + * \brief Load primitive op names. + * + * \param strm The input stream. + */ + void LoadPrimitiveOpNames(dmlc::Stream* strm); + + /*! + * \brief Load the vm functions. + * + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + + /*! \brief The serialized bytecode. */ + std::string code_; +}; + /*! \brief The virtual machine. * * The virtual machine contains all the current execution state, - * as well as the global view of functions, the global constant - * table, the compiled operators. + * as well as the executable. * * The goal is to have a single self-contained object, * enabling one to easily pass around VMs, execute them on - * multiple threads, or serialized them to disk or over the + * multiple threads, or serialize them to disk or over the * wire. */ class VirtualMachine : public runtime::ModuleNode { @@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } - /*! \brief The runtime module/library that contains generated code. */ - runtime::Module lib; + VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} + + /*! \brief load the executable for the virtual machine. + * \param exec The executable. + */ + void LoadExecutable(const Executable* exec); + + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; - /*! \brief The virtual machine's function table. */ - std::vector functions; /*! \brief The current stack of call frames. */ std::vector frames; - /*! \brief The global constant pool. */ - std::vector constants; /*! \brief The fuction table index of the current function. */ Index func_index; /*! \brief The current pointer to the code section. */ @@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register; + /*! \brief The executable the VM will operate on. */ + const Executable* exec; + /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs; @@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode { */ ObjectRef Invoke(const std::string& name, const std::vector& args); - VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} - /*! \brief Initialize the virtual machine for a set of contexts. * \param contexts The set of TVM contexts. */ @@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode { */ TVMContext GetParamsContext() const; - /*! - * \brief Load parameters from the parameter bytearray. - * \param params The binary file that contains parameters. - */ - void LoadParams(const std::string& params); - - /*! \brief A map from globals (as strings) to their index in the function map. - */ - std::unordered_map global_map; - - /*! \brief A mapping from the packed function (as string) to the index that - * corresponds to the position of the `packed_funcs` list. - */ - std::unordered_map primitive_map; - private: /*! \brief Invoke a global setting up the VM state to execute. * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index ceb98c4d251e..fff9c99e5007 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -37,8 +37,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import serializer -from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py deleted file mode 100644 index fde702b1cd04..000000000000 --- a/python/tvm/relay/backend/deserializer.py +++ /dev/null @@ -1,81 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine deserializer. - -Python interface for deserializing a Relay VM. -""" -from tvm import module -from tvm._ffi.runtime_ctypes import TVMByteArray -from . import _vm -from . import vm as rly_vm - -def _create_deserializer(code, lib): - """Create a deserializer object. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - - Returns - ------- - ret : Deserializer - The created virtual machine deserializer. - """ - if isinstance(code, (bytes, str)): - code = bytearray(code) - elif not isinstance(code, (bytearray, TVMByteArray)): - raise TypeError("vm is expected to be the type of bytearray or " + - "TVMByteArray, but received {}".format(type(code))) - - if not isinstance(lib, module.Module): - raise TypeError("lib is expected to be the type of tvm.module.Module" + - ", but received {}".format(type(lib))) - return _vm._Deserializer(code, lib) - - -class Deserializer: - """Relay VM deserializer. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - """ - def __init__(self, code, lib): - self.mod = _create_deserializer(code, lib) - self._deserialize = self.mod["deserialize"] - - def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM. - - Returns - ------- - ret : VirtualMachine - The deserialized Relay VM. - """ - return rly_vm.VirtualMachine(self._deserialize()) diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 8ae3161e0b83..b36715249f0a 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachineProfiler - The profile VM runtime. + exec : Executable + The executable with profiling code. """ compiler = VMCompilerProfiler() target = compiler.update_target(target) @@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachineProfiler(compiler._get_vm()) + return vm.Executable(compiler._get_exec()) class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" @@ -68,13 +68,17 @@ def __init__(self): super().__init__() self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] class VirtualMachineProfiler(vm.VirtualMachine): """Relay profile VM runtime.""" def __init__(self, mod): super().__init__(mod) + m = mod.module if isinstance(mod, vm.Executable) else mod + self.mod = _vm._VirtualMachineDebug(m) + self._init = self.mod["init"] + self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] def get_stat(self): diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py deleted file mode 100644 index b45ba9116a15..000000000000 --- a/python/tvm/relay/backend/serializer.py +++ /dev/null @@ -1,191 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine serializer. - -Python interface for serializing a Relay VM. -""" -import tvm -from . import _vm -from . import vm as rly_vm - -def _create_serializer(vm): - """Create a VM serializer. - - Parameters - ---------- - vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] - The virtual machine to be serialized. - - Returns - ------- - ret : Serializer - The created virtual machine serializer. - """ - if isinstance(vm, rly_vm.VirtualMachine): - vm = vm.module - elif not isinstance(vm, tvm.module.Module): - raise TypeError("vm is expected to be the type of VirtualMachine or " + - "tvm.Module, but received {}".format(type(vm))) - - return _vm._Serializer(vm) - - -class Serializer: - """Relay VM serializer.""" - def __init__(self, vm): - self.mod = _create_serializer(vm) - self._get_lib = self.mod["get_lib"] - self._get_bytecode = self.mod["get_bytecode"] - self._get_globals = self.mod["get_globals"] - self._get_stats = self.mod["get_stats"] - self._get_primitive_ops = self.mod["get_primitive_ops"] - self._serialize = self.mod["serialize"] - - @property - def stats(self): - """Get the statistics of the Relay VM. - - Returns - ------- - ret : String - The serialized statistic information. - """ - return self._get_stats() - - @property - def primitive_ops(self): - """Get the name of the primitive ops that are executed in the VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The list of primitive ops. - """ - return [prim_op.value for prim_op in self._get_primitive_ops()] - - @property - def bytecode(self): - """Get the bytecode of the Relay VM. - - Returns - ------- - ret : String - The serialized bytecode. - - Notes - ----- - The bytecode is in the following format: - func_name reg_file_size num_instructions - param1 param2 ... paramM - instruction1 - instruction2 - ... - instructionN - - Each instruction is printed in the following format: - hash opcode field1 ... fieldX # The text format. - - The part starting from # is only used for visualization and debugging. - The real serialized code doesn't contain it, therefore the deserializer - doesn't need to deal with it as well. - """ - return self._get_bytecode() - - @property - def globals(self): - """Get the globals used by the Relay VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The serialized globals. - """ - return [glb.value for glb in self._get_globals()] - - def serialize(self): - """Serialize the Relay VM. - - Returns - ------- - code : bytearray - The binary blob representing a serialized Relay VM. It can then be - saved to disk and later deserialized into a new VM. - - lib : :py:class:`~tvm.module.Module` - The runtime module that contains the generated code. It is - basically a library that is composed of hardware dependent code. - - Notes - ----- - The returned code is organized with the following sections in order. - - Global section. This section contains the globals used by the - virtual machine. - - Constant section. This section is used to store the constant pool of - a virtual machine. - - Primitive name section. This section is introduced to accommodate - the list of primitive operator names that will be invoked by the - virtual machine. - - Code section. The VM functions, including bytecode, are sitting in - this section. - - Examples - -------- - .. code-block:: python - - import numpy as np - import tvm - from tvm import relay - - # define a simple network. - x = relay.var('x', shape=(10, 10)) - f = relay.Function([x], x + x) - mod = relay.Module({"main": f}) - - # create a Relay VM. - ctx = tvm.cpu() - target = "llvm" - compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) - vm.init(ctx) - - # serialize. - ser = relay.serializer.Serializer(vm) - code, lib = ser.serialize() - - # save and load the code and lib file. - tmp = tvm.contrib.util.tempdir() - path_lib = tmp.relpath("lib.so") - lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: - fo.write(code) - - loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) - - # deserialize. - deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() - - # execute the deserialized vm. - des_vm.init(ctx) - x_data = np.random.rand(10, 10).astype('float32') - res = des_vm.run(x_data) - print(res.asnumpy()) - """ - return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index c24b16ca6437..942c93b866f4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,8 +24,8 @@ import tvm from tvm import autotvm -from tvm._ffi.runtime_ctypes import TVMByteArray from tvm.relay import expr as _expr +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj from .interpreter import Executor @@ -44,6 +44,7 @@ def _convert(arg, cargs): else: raise "unsupported type" + def convert(args): cargs = [] for arg in args: @@ -52,12 +53,202 @@ def convert(args): return cargs +class Executable(object): + """Relay VM executable""" + def __init__(self, mod): + self.mod = mod + self._save = self.mod["save"] + self._get_lib = self.mod["get_lib"] + self._get_bytecode = self.mod["get_bytecode"] + self._get_stats = self.mod["get_stats"] + + def save(self): + """Save the Relay VM Executable. + + Returns + ------- + code : bytearray + The binary blob representing a serialized Relay VM executable. It + can then be saved to disk and later deserialized into a new + Executable. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. It is + basically a library that is composed of hardware dependent code. + + Notes + ----- + The returned code is organized with the following sections in order. + - Global section. This section contains the globals used by the + virtual machine. + - Constant section. This section is used to store the constant pool of + a virtual machine. + - Primitive name section. This section is introduced to accommodate + the list of primitive operator names that will be invoked by the + virtual machine. + - Code section. The VM functions, including bytecode, are sitting in + this section. + + Examples + -------- + + .. code-block:: python + + import numpy as np + import tvm + from tvm import relay + # define a simple network. + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + mod = relay.Module({"main": f}) + # create a Relay VM. + ctx = tvm.cpu() + target = "llvm" + executable = relay.vm.compile(mod, target) + code, lib = executable.save() + # save and load the code and lib file. + tmp = tvm.contrib.util.tempdir() + path_lib = tmp.relpath("lib.so") + lib.export_library(path_lib) + with open(tmp.relpath("code.ro"), "wb") as fo: + fo.write(code) + loaded_lib = tvm.module.load(path_lib) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) + # deserialize. + des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_code) + # execute the deserialized executable. + x_data = np.random.rand(10, 10).astype('float32') + des_vm = relay.vm.VirtualMachine(des_exec) + des_vm.init(ctx) + res = des_vm.run(x_data) + print(res.asnumpy()) + """ + return self._save(), self._get_lib() + + @staticmethod + def load_exec(bytecode, lib): + """Construct an executable from saved artifacts. + + Parameters + ---------- + bytecode : bytearray + The binary blob representing a the Relay VM bytecode. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. + + Returns + ------- + exec: Executable + An executable constructed using the provided artifacts. + """ + if isinstance(bytecode, (bytes, str)): + code = bytearray(bytecode) + elif not isinstance(bytecode, (bytearray, TVMByteArray)): + raise TypeError("bytecode is expected to be the type of bytearray " + + "or TVMByteArray, but received {}".format(type(code))) + + if not isinstance(lib, tvm.module.Module): + raise TypeError("lib is expected to be the type of tvm.module.Module" + + ", but received {}".format(type(lib))) + + return Executable(_vm.Load_Executable(bytecode, lib)) + + @property + def lib(self): + """Get the library that contains hardware dependent code. + + Returns + ------- + ret : :py:class:`~tvm.Module` + The runtime module that contains hardware dependent code. + """ + return self._get_lib() + + @property + def stats(self): + """Get the statistics of the Relay VM executable. + + Returns + ------- + ret : String + The statistic information of the VM executable. + """ + return self._get_stats() + + @property + def primitive_ops(self): + """Get the name of the primitive ops contained in the executable. + + Returns + ------- + ret : List[String] + The list of primitive ops. + """ + ret = [] + num_primitives = _vm.GetNumOfPrimitives(self.module) + for i in range(num_primitives): + ret.append(_vm.GetPrimitiveFields(self.module, i)) + return ret + + @property + def bytecode(self): + """Get the bytecode of the Relay VM executable. + + Returns + ------- + ret : String + The bytecode of the executable. + + Notes + ----- + The bytecode is in the following format: + func_name reg_file_size num_instructions + param1 param2 ... paramM + instruction1 + instruction2 + ... + instructionN + + Each instruction is printed in the following format: + hash opcode field1 ... fieldX # The text format. + + The part starting from # is only used for visualization and debugging. + The real serialized code doesn't contain it, therefore the deserializer + doesn't need to deal with it as well. + """ + return self._get_bytecode() + + @property + def globals(self): + """Get the globals used by the Relay VM executable. + + Returns + ------- + ret : List[String] + The globals contained in the executable. + """ + ret = [] + num_globals = _vm.GetNumOfGlobals(self.module) + for i in range(num_globals): + ret.append(_vm.GetGlobalFields(self.module, i)) + return ret + + @property + def module(self): + """Return the runtime module contained in a virtual machine executable.""" + return self.mod + + class VirtualMachine(object): """Relay VM runtime.""" def __init__(self, mod): - self.mod = mod + if not isinstance(mod, (Executable, tvm.module.Module)): + raise TypeError("mod is expected to be the type of Executable or " + + "tvm.Module, but received {}".format(type(mod))) + m = mod.module if isinstance(mod, Executable) else mod + self.mod = _vm._VirtualMachine(m) self._init = self.mod["init"] - self._load_params = self.mod["load_params"] self._invoke = self.mod["invoke"] def init(self, ctx): @@ -71,23 +262,6 @@ def init(self, ctx): args = [ctx.device_type, ctx.device_id] self._init(*args) - def load_params(self, params): - """Load parameters for the VM. - - Parameters - ---------- - params : Union[bytearray, Dict] - The dictionary that contains serialized parameters. - """ - if isinstance(params, dict): - params = tvm.relay.save_param_dict(params) - elif isinstance(params, (bytes, str)): - params = bytearray(params) - if not isinstance(params, (bytearray, TVMByteArray)): - raise TypeError("params must be a bytearray") - - self._load_params(bytearray(params)) - def invoke(self, func_name, *args): """Invoke a function. @@ -122,11 +296,6 @@ def run(self, *args): """ return self.invoke("main", *args) - @property - def module(self): - """Return the runtime module contained in a virtual machine.""" - return self.mod - def compile(mod, target=None, target_host=None, params=None): """ @@ -155,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachine - The VM runtime. + exec : Executable + The VM executable that contains both library code and bytecode. """ compiler = VMCompiler() @@ -167,14 +336,14 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachine(compiler._get_vm()) + return Executable(compiler._get_exec()) class VMCompiler(object): """Build Relay module to run on VM runtime.""" def __init__(self): self.mod = _vm._VMCompiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] def set_params(self, params): @@ -240,7 +409,7 @@ class VMExecutor(Executor): mod : :py:class:`~tvm.relay.module.Module` The module to support the execution. - ctx : :py:class:`TVMContext` + ctx : :py:class:`~tvm.TVMContext` The runtime context to run the code on. target : :py:class:`Target` @@ -252,7 +421,8 @@ def __init__(self, mod, ctx, target): self.mod = mod self.ctx = ctx self.target = target - self.vm = compile(mod, target) + self.executable = compile(mod, target) + self.vm = VirtualMachine(self.executable) self.vm.init(ctx) def _make_executor(self, expr=None): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0cfae374ab2c..f295ccd7a555 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, Module mod = args[0]; this->Compile(mod, args[1], args[2]); }); - } else if (name == "get_vm") { + } else if (name == "get_executable") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(vm_); + *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, // Next we get ready by allocating space for // the global state. - vm_->functions.resize(context_.module->functions.size()); + exec_->functions.resize(context_.module->functions.size()); for (auto named_func : context_.module->functions) { auto gvar = named_func.first; @@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); - CHECK(func_index < vm_->functions.size()); - vm_->functions[func_index] = vm_func; + CHECK(func_index < exec_->functions.size()); + exec_->functions[func_index] = vm_func; } #if USE_RELAY_DEBUG - for (auto vm_func : vm_->functions) { + for (auto vm_func : exec_->functions) { DLOG(INFO) << vm_func << "-------------"; } #endif // USE_RELAY_DEBUG // populate constants for (auto data : context_.constants) { - vm_->constants.push_back(runtime::vm::Tensor(data)); + exec_->constants.push_back(runtime::vm::Tensor(data)); } LibraryCodegen(); for (auto gv : context_.global_map) { - vm_->global_map.insert({gv.first->name_hint, gv.second}); + exec_->global_map.insert({gv.first->name_hint, gv.second}); } } @@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { // therefore target won't be used in the build function runtime::Module mod = (*f)(funcs, Target(), target_host_); CHECK(mod.operator->()); - vm_->lib = mod; + exec_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } size_t primitive_index = 0; for (auto cfunc : cached_funcs) { - vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index dff1ef7f4569..215cc12c4cdb 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { return "VMCompiler"; } - std::shared_ptr GetVirtualMachine() const { - return vm_; - } - - virtual void InitVM() { - vm_ = std::make_shared(); + void InitVM() { + exec_ = std::make_shared(); } /*! @@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { tvm::Target target_host_; /*! \brief Global shared meta data */ VMCompilerContext context_; - /*! \brief Compiled virtual machine. */ - std::shared_ptr vm_; + /*! \brief Compiled executable. */ + std::shared_ptr exec_; /*! \brief parameters */ std::unordered_map params_; }; diff --git a/src/relay/backend/vm/deserializer.cc b/src/relay/backend/vm/deserializer.cc deleted file mode 100644 index 777282782e99..000000000000 --- a/src/relay/backend/vm/deserializer.cc +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.cc - * \brief Implementation of APIs to deserialize the serialized VM bytecode. - */ - -#include "deserializer.h" - -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -#define STREAM_CHECK(val, section) \ - CHECK(val) << "Invalid VM file format in the " << section << " section." \ - << "\n"; - -void Deserializer::Init(const std::string& code, const runtime::Module& lib) { - code_ = code; - vm_ = std::make_shared(); - vm_->lib = lib; - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Deserializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "deserialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Deserialize(); - *rv = runtime::Module(vm_); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -void Deserializer::Deserialize() { - // Check header. - uint64_t header; - STREAM_CHECK(strm_->Read(&header), "header"); - STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); - - // Check version. - std::string version; - STREAM_CHECK(strm_->Read(&version), "version"); - STREAM_CHECK(version == TVM_VERSION, "version"); - - // Global section. - DeserializeGlobalSection(); - - // Constant section. - DeserializeConstantSection(); - - // Primitive names that will be invoked by `InvokePacked` instructions. - DeserializePrimitiveOpNames(); - - // Code section. - DeserializeCodeSection(); -} - -void Deserializer::DeserializeGlobalSection() { - std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); - for (size_t i = 0; i < globals.size(); i++) { - vm_->global_map.insert({globals[i], i}); - } -} - -void Deserializer::DeserializeConstantSection() { - uint64_t sz; - // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); - - size_t size = static_cast(sz); - // Load each of the constants. - for (size_t i = 0; i < size; i++) { - runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); - runtime::ObjectRef obj = runtime::vm::Tensor(constant); - vm_->constants.push_back(obj); - } -} - -void Deserializer::DeserializePrimitiveOpNames() { - std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); - for (size_t i = 0; i < primitive_names.size(); i++) { - vm_->primitive_map.insert({primitive_names[i], i}); - } -} - -// Extract the `cnt` number of fields started at `start` from the list -// `instr_fields`. -inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, - Index cnt) { - CHECK_LE(static_cast(start + cnt), instr_fields.size()); - std::vector ret; - for (auto i = start; i < start + cnt; i++) { - ret.push_back(instr_fields[i]); - } - return ret; -} - -Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { - Opcode opcode = static_cast(instr.opcode); - switch (opcode) { - case Opcode::Move: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::Move(instr.fields[0], instr.fields[1]); - } - case Opcode::Ret: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Ret(instr.fields[0]); - } - case Opcode::Fatal: { - // Number of fields = 0 - DCHECK(instr.fields.empty()); - return Instruction::Fatal(); - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index packed_index = instr.fields[0]; - Index arity = instr.fields[1]; - Index output_size = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, arity); - return Instruction::InvokePacked(packed_index, arity, output_size, args); - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 5U); - DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); - - DLDataType dtype; - dtype.code = instr.fields[0]; - dtype.bits = instr.fields[1]; - dtype.lanes = instr.fields[2]; - - Index ndim = instr.fields[3]; - RegName dst = instr.fields[4]; - - std::vector shape = ExtractFields(instr.fields, 5, ndim); - - return Instruction::AllocTensor(shape, dtype, dst); - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 5U); - Index shape_register = instr.fields[0]; - - DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; - - RegName dst = instr.fields[4]; - - return Instruction::AllocTensorReg(shape_register, dtype, dst); - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index constructor_tag = instr.fields[0]; - Index num_fields = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector fields = ExtractFields(instr.fields, 3, num_fields); - - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index clo_index = instr.fields[0]; - Index num_freevar = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); - - return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); - } - case Opcode::If: { - // Number of fields = 4 - DCHECK_EQ(instr.fields.size(), 4U); - Index test = instr.fields[0]; - Index target = instr.fields[1]; - Index true_offset = instr.fields[2]; - Index false_offset = instr.fields[3]; - - return Instruction::If(test, target, true_offset, false_offset); - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index func_index = instr.fields[0]; - Index num_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_args); - - return Instruction::Invoke(func_index, args, dst); - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index closure = instr.fields[0]; - Index num_closure_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_closure_args); - - return Instruction::InvokeClosure(closure, args, dst); - } - case Opcode::LoadConst: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConst(instr.fields[0], instr.fields[1]); - } - case Opcode::LoadConsti: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); - } - case Opcode::GetField: { - // Number of fields = 3 - DCHECK_EQ(instr.fields.size(), 3U); - return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); - } - case Opcode::GetTag: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::GetTag(instr.fields[0], instr.fields[1]); - } - case Opcode::Goto: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Goto(instr.fields[0]); - } - default: - LOG(FATAL) << "Invalid opcode" << instr.opcode; - return Instruction(); - } -} - -void Deserializer::DeserializeCodeSection() { - // Load the number of functions. - uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); - - size_t num_funcs = static_cast(sz); - vm_->functions.resize(num_funcs); - for (size_t i = 0; i < num_funcs; i++) { - // Load the function info. - VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); - - // Load the instructions. - std::vector instructions; - for (size_t j = 0; j < loaded_func.num_instructions; j++) { - VMInstructionSerializer instr; - std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); - instructions.push_back(DeserializeInstruction(instr)); - } - - // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, - loaded_func.register_file_size); - auto it = vm_->global_map.find(loaded_func.name); - CHECK(it != vm_->global_map.end()); - CHECK_LE(it->second, vm_->global_map.size()); - vm_->functions[it->second] = vm_func; - } -} - -runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->Init(code, lib); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Deserializer") -.set_body_typed(CreateDeserializer); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/deserializer.h b/src/relay/backend/vm/deserializer.h deleted file mode 100644 index 0caf72bee92c..000000000000 --- a/src/relay/backend/vm/deserializer.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.h - * \brief Define a deserializer for the serialized Relay VM. - */ - -#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime::vm; -namespace runtime = tvm::runtime; - -class Deserializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the deserializer for creating a virtual machine object. - * - * \param code The serialized code. - * \param lib The serialized runtime module/library that contains the - * hardware dependent code. - */ - inline void Init(const std::string& code, const runtime::Module& lib); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Deserializer"; } - - /*! \brief Deserialize the serialized VM. */ - void Deserialize(); - - virtual ~Deserializer() { delete strm_; } - - private: - /*! \brief Deserialize the globals in `vm_`. */ - void DeserializeGlobalSection(); - - /*! \brief Deserialize the constant pool in `vm_`. */ - void DeserializeConstantSection(); - - /*! \brief Deserialize primitive op names in `vm_`. */ - void DeserializePrimitiveOpNames(); - - /*! \brief Deserialize the vm functions in `vm_`. */ - void DeserializeCodeSection(); - - /*! \brief The code to be serialized. */ - std::string code_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The VM to be created. */ - std::shared_ptr vm_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc index 9fd28e8c7f46..60c441a60cf0 100644 --- a/src/relay/backend/vm/profiler/compiler.cc +++ b/src/relay/backend/vm/profiler/compiler.cc @@ -33,7 +33,6 @@ namespace vm { class VMCompilerDebug : public VMCompiler { public: VMCompilerDebug() {} - void InitVM() override { vm_ = std::make_shared(); } virtual ~VMCompilerDebug() {} }; diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc deleted file mode 100644 index 0040ef9db470..000000000000 --- a/src/relay/backend/vm/serializer.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM. - */ -#include "serializer.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -void Serializer::Init(const VirtualMachine* vm) { - vm_ = vm; - // Initialize the stream object. - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Serializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); - } else if (name == "get_primitive_ops") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPrimitiveOps(); - }); - } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); - } else if (name == "get_globals") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGlobals(); - }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); - } else if (name == "serialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Serialize(); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -tvm::Array Serializer::GetPrimitiveOps() const { - std::vector ret; - for (const auto& it : vm_->primitive_map) { - auto packed_name = tvm::ir::StringImm::make(it.first); - auto packed_index = static_cast(it.second); - if (ret.size() <= packed_index) { - ret.resize(packed_index + 1); - } - ret[packed_index] = packed_name; - } - return ret; -} - -std::string Serializer::Stats() const { - std::ostringstream oss; - oss << "Relay VM statistics:" << std::endl; - - // Get the number of constants and the shape of each of them. - oss << " Constant shapes (# " << vm_->constants.size() << "): ["; - for (const auto& it : vm_->constants) { - auto* cell = it.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); - - // Scalar - if (shape.empty()) { - oss << "scalar, "; - continue; - } - - oss << "["; - for (auto s : shape) { - oss << s << ", "; - } - oss.seekp(-2, oss.cur); - oss << "], " << std::endl; - } - if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of globals and the name of each of them. - oss << " Globals (#" << vm_->global_map.size() << "): ["; - for (const auto& it : vm_->global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; - } - if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of primitive ops and the name of each of them. - oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; - const auto& prim_ops = GetPrimitiveOps(); - for (const auto& it : prim_ops) { - oss << it << ", "; - } - if (!prim_ops.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - return oss.str(); -} - -TVMByteArray Serializer::Serialize() { - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); - - // Global section. - SerializeGlobalSection(); - - // Constant section. - SerializeConstantSection(); - - // Primitive names. - SerializePrimitiveOpNames(); - - // Code section. - SerializeCodeSection(); - - TVMByteArray arr; - arr.data = code_.c_str(); - arr.size = code_.length(); - return arr; -} - -void Serializer::SerializeGlobalSection() { - auto globals = GetGlobals(); - std::vector glbs; - for (const auto& it : globals) { - glbs.push_back(it.as()->value); - } - strm_->Write(glbs); -} - -void Serializer::SerializeConstantSection() { - std::vector arrays; - for (const auto& obj : vm_->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); - } - strm_->Write(static_cast(vm_->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); - } -} - -void Serializer::SerializePrimitiveOpNames() { - auto names = GetPrimitiveOps(); - std::vector primitive_names; - for (const auto& it : names) { - primitive_names.push_back(it.as()->value); - } - strm_->Write(primitive_names); -} - -// Serialize a virtual machine instruction. It creates a list that contains the -// hash, opcode, and all fields of an instruction. -// -// For example, the function signature used to create an `AllocTensor` -// instruction is: -// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) -// -// The serialized form will be: -// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` -// -// where hash is the hash of serialized instruction that is computed internally -// by the `VMInstructionSerializer`. It is used for sanity check before decoding. -// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` -// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` -// is the destination register, and the rest of it together indicates the shape -// of the tensor to be allocated. -VMInstructionSerializer SerializeInstruction(const Instruction& instr) { - std::vector fields; - // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; - switch (instr.op) { - case Opcode::Move: { - // Number of fields = 2 - fields.assign({instr.from, instr.dst}); - break; - } - case Opcode::Ret: { - // Number of fields = 1 - fields.push_back(instr.result); - break; - } - case Opcode::Fatal: { - // Number of fields = 0 - break; - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - // Note that arity includes both input arguments and outputs. We will - // put all the `arity` number of fields in the end for serialization. - fields.assign({instr.packed_index, instr.arity, instr.output_size}); - // Save the args. - fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); - break; - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - - // The number of dimensions is not needed for constructing an - // `AllocTensor` instruction as it equals to the length of the `shape` - // vector. However, we save it to conveniently deserialize the instruction - // because we will know how many fields are needed by the `shape` argument. - fields.push_back(instr.alloc_tensor.ndim); - fields.push_back(instr.dst); - - // Save the shape of the tensor. - // Note that this field is rotated to the end of the list. - fields.insert(fields.end(), instr.alloc_tensor.shape, - instr.alloc_tensor.shape + instr.alloc_tensor.ndim); - break; - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - fields.push_back(instr.alloc_tensor_reg.shape_register); - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - fields.push_back(instr.dst); - break; - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); - - // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); - break; - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); - - // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); - break; - } - case Opcode::If: { - // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, - instr.if_op.false_offset}); - break; - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - fields.assign({instr.func_index, instr.num_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.invoke_args_registers, - instr.invoke_args_registers + instr.num_args); - break; - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - fields.assign({instr.closure, instr.num_closure_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); - break; - } - case Opcode::LoadConst: { - // Number of fields = 2 - fields.assign({instr.const_index, instr.dst}); - break; - } - case Opcode::LoadConsti: { - // Number of fields = 2 - fields.assign({instr.load_consti.val, instr.dst}); - break; - } - case Opcode::GetField: { - // Number of fields = 3 - fields.assign({instr.object, instr.field_index, instr.dst}); - break; - } - case Opcode::GetTag: { - // Number of fields = 2 - fields.assign({instr.get_tag.object, instr.dst}); - break; - } - case Opcode::Goto: { - // Number of fields = 1 - fields.push_back(instr.pc_offset); - break; - } - default: - LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); - break; - } - - return VMInstructionSerializer(static_cast(instr.op), fields); -} - -void Serializer::SerializeCodeSection() { - // Save the number of functions. - strm_->Write(static_cast(vm_->functions.size())); - for (const auto& func : vm_->functions) { - // Serialize the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), - func.params); - func_format.Save(strm_); - - // Serialize each instruction. - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); - } - } -} - -tvm::Array Serializer::GetGlobals() const { - tvm::Array ret; - std::vector > globals(vm_->global_map.begin(), - vm_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - for (const auto& it : globals) { - ret.push_back(tvm::ir::StringImm::make(it.first)); - } - return ret; -} - -std::string Serializer::GetBytecode() const { - std::ostringstream oss; - - for (const auto& func : vm_->functions) { - // Print the header of the function format. - oss << "# func name, reg file size, param count, inst count:" - << std::endl; - oss << func.name << " " - << func.register_file_size << " " - << func.params.size() << " " - << func.instructions.size() << std::endl; - - // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; - for (const auto& param : func.params) { - oss << param << " "; - } - oss << std::endl; - - // Print the instructions of a `VMFunction`. - // The part after ";" is the instruction in text format. - oss << "hash, opcode, fields # inst(text):"<< std::endl; - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - oss << std::hex << "0x" << serialized_instr.Hash() << " " - << std::dec << serialized_instr.opcode << " "; - for (auto it : serialized_instr.fields) { - oss << it << " "; - } - oss << " # " << instr; - if (oss.str().back() != '\n') oss << std::endl; - } - } - - return oss.str(); -} - -runtime::Module Serializer::GetLib() const { - return vm_->lib; -} - -runtime::Module CreateSerializer(const VirtualMachine* vm) { - std::shared_ptr exec = std::make_shared(); - exec->Init(vm); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Serializer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - runtime::Module mod = args[0]; - const auto* vm = dynamic_cast(mod.operator->()); - CHECK(vm) << "Virtual machine has not been defined yet." - << "\n"; - *rv = CreateSerializer(vm); -}); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/serializer.h b/src/relay/backend/vm/serializer.h deleted file mode 100644 index 2371bb4c94f5..000000000000 --- a/src/relay/backend/vm/serializer.h +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.h - * \brief Define a serializer for the Relay VM. - * - * The following components of a Relay VM will be serialized: - * - The `constants`, e.g., the constant pool, that contains the - * constants used in a Relay program. - * - The `packed_funcs` that essentially contains the generated code for - * a specific target. We return it as a runtime module that can be exported as - * a library file (e.g., .so, .o, or .tar). - * - The `global_map` that contains the globals. - * - The `primitive_map` that contains the name of individual primitive operators. - * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of - * a list of instructions/bytecode. - * - * Note that only the library is returned as a separate module. All othere parts - * are stored in a single serialized code that is organized with the following - * sections in order. - * - Global section, containing all globals. - * - Constant section, storing the constant pool. - * - Primitive name section, containing the function name of the primitive ops - * used by the virtual machine. - * - Code section, handling the VM functions and bytecode. - * - * The code section is again organized as follows for each VM function: - * func_name, register_file_size, num_instructions (N) - * param1, param2, ..., paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Serializing an `Instruction` requires us to deal with the bytecode. Each line - * of the instructions could be serialized as the following format: - * hash, opcode, f1, f2, ..., fX, field with variable length - * 1. hash: the hash of the instruction. This number will be used to help us - * validate if an instruction is well-formed during deserialization. - * 2. opcode: the opcode code of the instruction. - * 3. f1, f2, ..., fX. These fields together represent the fixed fields in - * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For - * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). - * 4. The rest of the line indicates the field with variable length, e.g., - * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. - */ - -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime; -using namespace tvm::runtime::vm; - -/*! - * \brief The Relay VM serializer. - */ -class Serializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the serializer for a virtual machine. - * - * \param vm The Relay virtual machine. - */ - inline void Init(const VirtualMachine* vm); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Serializer"; } - - /*! - * \brief Print the detailed statistics of the given code, i.e. number of - * globls and constants, etc. - */ - std::string Stats() const; - - /*! - * \brief Serialize the `vm_` into global section, constant section, and code - * section. - * - * \return The binary representation of the VM. - */ - TVMByteArray Serialize(); - - /*! - * \brief Get a list of the globals used by the `_vm`. - * - * \return The global map in the form a list. - */ - tvm::Array GetGlobals() const; - - /*! - * \brief Get the primitive operators that are contained in the Relay VM. - * - * \return The list of primitve operators. - */ - tvm::Array GetPrimitiveOps() const; - - /*! - * \brief Get the serialized form of the `functions` in `vm_`. This is - * essentially bytecode serialization. - * - * \return The serialized vm bytecode. - * - * \note The bytecode is in the following format: - * func_name reg_file_size num_instructions - * param1 param2 ... paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Each instruction is printed in the following format: - * opcode num_fields field1 ... fieldX # The text format. - * - * The field starting from # is only used for debugging. The serialized code - * doesn't contain it, therefore the deserializer doens't need to handle it. - */ - std::string GetBytecode() const; - - /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module` - * has already been supported by TVM. Therefore, we only return the runtime - * module and let users have the flexibility to call `export_library` from - * the frontend to save the library to disk. - * - * \return The runtime module that contains the hardwre dependent code. - */ - inline runtime::Module GetLib() const; - - virtual ~Serializer() { delete strm_; } - - private: - /*! \brief Serialize the globals in vm_. */ - void SerializeGlobalSection(); - - /*! \brief Serialize the constant pool in vm_. */ - void SerializeConstantSection(); - - /*! \brief Serialize primitive op names in vm_. */ - void SerializePrimitiveOpNames(); - - /*! \brief Serialize the vm functions in vm_. */ - void SerializeCodeSection(); - - /*! \brief The Relay virtual machine for to be serialized. */ - const VirtualMachine* vm_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The serialized code. */ - std::string code_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc new file mode 100644 index 000000000000..21f71af4eb8c --- /dev/null +++ b/src/runtime/vm/executable.cc @@ -0,0 +1,734 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/runtime/vm/executable.cc + * \brief The implementation of a virtual machine executable APIs. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace runtime { +namespace vm { + +#define STREAM_CHECK(val, section) \ + CHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; + +// Helper to serialize a vm instruction. +VMInstructionSerializer SerializeInstruction(const Instruction& instr); +// Helper to deserialize a serialized vm instruction. +Instruction DeserializeInstruction(const VMInstructionSerializer& instr); + +PackedFunc Executable::GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetLib(); + }); + } else if (name == "get_bytecode") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetBytecode(); + }); + } else if (name == "get_stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Stats(); + }); + } else if (name == "save") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Save(); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } +} + +std::string Executable::GetBytecode() const { + std::ostringstream oss; + + for (const auto& func : functions) { + // Print the header of the function format. + oss << "# func name, reg file size, param count, inst count:" + << std::endl; + oss << func.name << " " + << func.register_file_size << " " + << func.params.size() << " " + << func.instructions.size() << std::endl; + + // Print pramams of a `VMFunction`. + oss << "# Parameters: "<< std::endl; + for (const auto& param : func.params) { + oss << param << " "; + } + oss << std::endl; + + // Print the instructions of a `VMFunction`. + // The part after ";" is the instruction in text format. + oss << "hash, opcode, fields # inst(text):"<< std::endl; + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + oss << std::hex << "0x" << serialized_instr.Hash() << " " + << std::dec << serialized_instr.opcode << " "; + for (auto it : serialized_instr.fields) { + oss << it << " "; + } + oss << " # " << instr; + if (oss.str().back() != '\n') oss << std::endl; + } + } + + return oss.str(); +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relay VM executable statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << constants.size() << "): ["; + for (const auto& it : constants) { + const auto* cell = it.as(); + CHECK(cell); + runtime::NDArray data = cell->data; + const auto& shape = data.Shape(); + + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << global_map.size() << "): ["; + for (const auto& it : global_map) { + oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + } + if (!global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of primitive ops and the name of each of them. + oss << " Primitive ops (#" << primitive_map.size() << "): ["; + std::vector prim_ops; + for (const auto& it : primitive_map) { + auto packed_index = static_cast(it.second); + if (prim_ops.size() <= packed_index) { + prim_ops.resize(packed_index + 1); + } + prim_ops[packed_index] = it.first; + } + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +TVMByteArray Executable::Save() { + // Initialize the stream object. + code_.clear(); + dmlc::MemoryStringStream strm(&code_); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Primitive names. + SavePrimitiveOpNames(&strm); + + // Code section. + SaveCodeSection(&strm); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { + std::vector > globals(this->global_map.begin(), + this->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.first); + } + strm->Write(glbs); +} + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + std::vector arrays; + for (const auto& obj : this->constants) { + const auto* cell = obj.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm->Write(static_cast(this->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm, it); + } +} + +void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + for (const auto& it : this->primitive_map) { + auto packed_index = static_cast(it.second); + if (primitive_names.size() <= packed_index) { + primitive_names.resize(packed_index + 1); + } + primitive_names[packed_index] = it.first; + } + strm->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionExecutable`. It is used for sanity check before decoding. +// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. + DLOG(INFO) << "Serializing: " << instr << std::endl; + switch (instr.op) { + case Opcode::Move: { + // Number of fields = 2 + fields.assign({instr.from, instr.dst}); + break; + } + case Opcode::Ret: { + // Number of fields = 1 + fields.push_back(instr.result); + break; + } + case Opcode::Fatal: { + // Number of fields = 0 + break; + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + // Note that arity includes both input arguments and outputs. We will + // put all the `arity` number of fields in the end for serialization. + fields.assign({instr.packed_index, instr.arity, instr.output_size}); + // Save the args. + fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); + break; + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + + // The number of dimensions is not needed for constructing an + // `AllocTensor` instruction as it equals to the length of the `shape` + // vector. However, we save it to conveniently deserialize the instruction + // because we will know how many fields are needed by the `shape` argument. + fields.push_back(instr.alloc_tensor.ndim); + fields.push_back(instr.dst); + + // Save the shape of the tensor. + // Note that this field is rotated to the end of the list. + fields.insert(fields.end(), instr.alloc_tensor.shape, + instr.alloc_tensor.shape + instr.alloc_tensor.ndim); + break; + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + fields.push_back(instr.alloc_tensor_reg.shape_register); + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(instr.dst); + break; + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); + + // Save the fields. + fields.insert(fields.end(), instr.datatype_fields, + instr.datatype_fields + instr.num_fields); + break; + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); + + // Save the free vars. + fields.insert(fields.end(), instr.free_vars, + instr.free_vars + instr.num_freevar); + break; + } + case Opcode::If: { + // Number of fields = 4 + fields.assign({instr.if_op.test, + instr.if_op.target, + instr.if_op.true_offset, + instr.if_op.false_offset}); + break; + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + fields.assign({instr.func_index, instr.num_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.invoke_args_registers, + instr.invoke_args_registers + instr.num_args); + break; + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + fields.assign({instr.closure, instr.num_closure_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + // Save the number of functions. + strm->Write(static_cast(this->functions.size())); + for (const auto& func : this->functions) { + // Save the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm); + + // Serialize each instruction. + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm); + } + } +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->lib = lib; + exec->code_ = code; + dmlc::MemoryStringStream strm(&exec->code_); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Primitive names that will be invoked by `InvokePacked` instructions. + exec->LoadPrimitiveOpNames(&strm); + + // Code section. + exec->LoadCodeSection(&strm); + + return runtime::Module(exec); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + std::vector globals; + STREAM_CHECK(strm->Read(&globals), "global"); + for (size_t i = 0; i < globals.size(); i++) { + this->global_map.insert({globals[i], i}); + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm), "constant"); + runtime::ObjectRef obj = runtime::vm::Tensor(constant); + this->constants.push_back(obj); + } +} + +void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); + for (size_t i = 0; i < primitive_names.size(); i++) { + this->primitive_map.insert({primitive_names[i], i}); + } +} + +// Extract the `cnt` number of fields started at `start` from the list +// `instr_fields`. +inline std::vector ExtractFields(const std::vector& instr_fields, + Index start, + Index cnt) { + CHECK_LE(static_cast(start + cnt), instr_fields.size()); + std::vector ret; + for (auto i = start; i < start + cnt; i++) { + ret.push_back(instr_fields[i]); + } + return ret; +} + +Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { + Opcode opcode = static_cast(instr.opcode); + switch (opcode) { + case Opcode::Move: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::Move(instr.fields[0], instr.fields[1]); + } + case Opcode::Ret: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Ret(instr.fields[0]); + } + case Opcode::Fatal: { + // Number of fields = 0 + DCHECK(instr.fields.empty()); + return Instruction::Fatal(); + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index packed_index = instr.fields[0]; + Index arity = instr.fields[1]; + Index output_size = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, arity); + return Instruction::InvokePacked(packed_index, arity, output_size, args); + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 5U); + DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); + + DLDataType dtype; + dtype.code = instr.fields[0]; + dtype.bits = instr.fields[1]; + dtype.lanes = instr.fields[2]; + + Index ndim = instr.fields[3]; + RegName dst = instr.fields[4]; + + std::vector shape = ExtractFields(instr.fields, 5, ndim); + + return Instruction::AllocTensor(shape, dtype, dst); + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + DCHECK_EQ(instr.fields.size(), 5U); + Index shape_register = instr.fields[0]; + + DLDataType dtype; + dtype.code = instr.fields[1]; + dtype.bits = instr.fields[2]; + dtype.lanes = instr.fields[3]; + + RegName dst = instr.fields[4]; + + return Instruction::AllocTensorReg(shape_register, dtype, dst); + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index constructor_tag = instr.fields[0]; + Index num_fields = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector fields = ExtractFields(instr.fields, 3, num_fields); + + return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index clo_index = instr.fields[0]; + Index num_freevar = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); + + return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); + } + case Opcode::If: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + Index test = instr.fields[0]; + Index target = instr.fields[1]; + Index true_offset = instr.fields[2]; + Index false_offset = instr.fields[3]; + + return Instruction::If(test, target, true_offset, false_offset); + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index func_index = instr.fields[0]; + Index num_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_args); + + return Instruction::Invoke(func_index, args, dst); + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index closure = instr.fields[0]; + Index num_closure_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_closure_args); + + return Instruction::InvokeClosure(closure, args, dst); + } + case Opcode::LoadConst: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConst(instr.fields[0], instr.fields[1]); + } + case Opcode::LoadConsti: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); + } + case Opcode::GetField: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); + } + case Opcode::GetTag: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::GetTag(instr.fields[0], instr.fields[1]); + } + case Opcode::Goto: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Goto(instr.fields[0]); + } + default: + LOG(FATAL) << "Invalid opcode" << instr.opcode; + return Instruction(); + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + // Load the number of functions. + uint64_t sz; + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code"); + + size_t num_funcs = static_cast(sz); + this->functions.resize(num_funcs); + for (size_t i = 0; i < num_funcs; i++) { + // Load the function info. + VMFunctionSerializer loaded_func; + STREAM_CHECK(loaded_func.Load(strm), "code/function"); + + // Load the instructions. + std::vector instructions; + for (size_t j = 0; j < loaded_func.num_instructions; j++) { + VMInstructionSerializer instr; + std::vector instr_fields; + STREAM_CHECK(instr.Load(strm), "code/instruction"); + instructions.push_back(DeserializeInstruction(instr)); + } + + // Create the VM function. + VMFunction vm_func = VMFunction(loaded_func.name, + loaded_func.params, + instructions, + loaded_func.register_file_size); + auto it = this->global_map.find(loaded_func.name); + CHECK(it != this->global_map.end()); + CHECK_LE(it->second, this->global_map.size()); + this->functions[it->second] = vm_func; + } +} + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->global_map.size()); +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + std::vector > globals(exec->global_map.begin(), + exec->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + CHECK_LT(idx, globals.size()); + *rv = globals[idx].first; +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->primitive_map.size()); +}); + + +TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + CHECK_GE(idx, 0); + CHECK_LT(idx, exec->primitive_map.size()); + + for (const auto& it : exec->primitive_map) { + if (idx == static_cast(it.second)) { + *rv = it.first; + break; + } + } +}); + +TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") +.set_body_typed([]( + std::string code, + runtime::Module lib) { + return Executable::Load(code, lib); +}); + +} // namespace vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 80e0ce57a8ae..821de0bda245 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction( } } -void VirtualMachineDebug::Init(const std::vector& ctxs) { - VirtualMachine::Init(ctxs); - for (auto kv : primitive_map) { +void VirtualMachineDebug::LoadExecutable(const Executable* exec) { + VirtualMachine::LoadExecutable(exec); + CHECK(this->exec); + for (auto kv : this->exec->primitive_map) { packed_index_map[kv.second] = kv.first; op_invokes[kv.second] = 0; } } +void VirtualMachineDebug::Init(const std::vector& ctxs) { + VirtualMachine::Init(ctxs); +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { - auto ctx = VirtualMachine::GetParamsContext(); + CHECK(this->exec); + auto ctx = this->GetParamsContext(); // warmup VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); @@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, op_invokes[packed_index] += 1; } +runtime::Module CreateVirtualMachineDebug(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "Virtual machine has not been defined yet." + << "\n"; + *rv = CreateVirtualMachineDebug(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 447967cafeb0..ff3296cb6c16 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine { void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void LoadExecutable(const Executable* exec); + ~VirtualMachineDebug() {} private: diff --git a/src/relay/backend/vm/serialize_util.h b/src/runtime/vm/serialize_util.h similarity index 95% rename from src/relay/backend/vm/serialize_util.h rename to src/runtime/vm/serialize_util.h index 3e7508ebee9b..3931f2f0e023 100644 --- a/src/relay/backend/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -19,11 +19,11 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serialize_util.h + * \file src/runtime/vm/serialize_util.h * \brief Definitions of helpers for serializing and deserializing a Relay VM. */ -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ +#define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ #include #include @@ -34,7 +34,7 @@ #include namespace tvm { -namespace relay { +namespace runtime { namespace vm { /*! \brief The magic number for the serialized VM bytecode file */ @@ -158,7 +158,7 @@ struct VMInstructionSerializer { }; } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 7dea9bdb95ea..78b74768b930 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(exec) << "The executable is not created yet."; std::string func_name = args[0]; - auto gvit = this->global_map.find(func_name); - CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; + auto gvit = exec->global_map.find(func_name); + CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name; auto func_index = gvit->second; - const auto& vm_func = this->functions[func_index]; + const auto& vm_func = exec->functions[func_index]; const auto& param_names = vm_func.params; auto ctx = this->GetParamsContext(); @@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } this->Init(contexts); }); - } else if (name == "load_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } TVMContext VirtualMachine::GetParamsContext() const { + CHECK(!ctxs.empty()) << "Context has not been initialized yet." + << "\n"; + // Use the fallback device if no device index is available. int fallback_device_type = static_cast(ctxs[0].device_type); // TODO(wweic): For heterogeneous execution, get device information from byte const auto& cit = - std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); + std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { + return fallback_device_type == static_cast(c.device_type); + }); return (cit == ctxs.end() ? ctxs[0] : *cit); } -void VirtualMachine::LoadParams(const std::string& params) { - dmlc::MemoryStringStream mss(const_cast(¶ms)); - dmlc::Stream* strm = &mss; - uint64_t header, reserved; - CHECK(strm->Read(&header)) << "Invalid parameter file"; - CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file"; - CHECK(strm->Read(&reserved)) << "Invalid parameter file"; - - std::vector names; - CHECK(strm->Read(&names)) << "Invalid parameter file"; - - uint64_t sz; - strm->Read(&sz); - size_t size = static_cast(sz); - CHECK(size == names.size()) << "Invalid parameter file"; - - auto ctx = GetParamsContext(); - for (size_t i = 0; i < size; i++) { - NDArray arr; - CHECK(arr.Load(strm)) << "Invalid parameter file"; - ObjectRef obj = Tensor(arr); - auto copy = CopyTo(obj, ctx); - params_.emplace(std::make_pair(names[i], copy)); - } -} - void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vectorGetAllocator(ctxs[0]); DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector& args) { - auto func_index = this->global_map[name]; + CHECK(exec) << "The executable has not been created yet."; + auto func_index = exec->global_map.at(name); DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; - return Invoke(this->functions[func_index], args); + return Invoke(exec->functions[func_index], args); } void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, @@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const std::vector& ctxs) { - this->ctxs = ctxs; +void VirtualMachine::LoadExecutable(const Executable* exec) { + CHECK(exec) << "The executable is not created yet."; + this->exec = exec; + runtime::Module lib = this->exec->lib; // Get the list of packed functions. - CHECK(primitive_map.empty() || lib.operator->()) + CHECK(exec->primitive_map.empty() || lib.operator->()) << "runtime module should have been built for primitive functions" << "\n"; - for (const auto& it : primitive_map) { + for (const auto& it : this->exec->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast(it.second); if (packed_funcs.size() <= packed_index) { @@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector& ctxs) { } } + +void VirtualMachine::Init(const std::vector& ctxs) { + this->ctxs = ctxs; +} + inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames.back().register_file[r] = val; } @@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { void VirtualMachine::RunLoop() { CHECK(this->code); + CHECK(this->exec); this->pc = 0; Index frame_start = frames.size(); while (true) { @@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { - auto constant_obj = this->constants[instr.const_index]; + auto constant_obj = exec->constants[instr.const_index]; + // TODO(wweic) ctx could be obtained from the ctxs list. auto device_obj = CopyTo(constant_obj, ctxs[0]); WriteRegister(instr.dst, device_obj); pc++; @@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_args; ++i) { args.push_back(ReadRegister(instr.invoke_args_registers[i])); } - InvokeGlobal(this->functions[instr.func_index], args); + InvokeGlobal(exec->functions[instr.func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_closure_args; ++i) { args.push_back(ReadRegister(instr.closure_args[i])); } - InvokeGlobal(this->functions[closure->func_index], args); + InvokeGlobal(exec->functions[closure->func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { shape[i] = instr.alloc_tensor.shape[i]; } + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto obj = Tensor(data); @@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() { auto num_dims = shape_tensor->shape[0]; auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto obj = Tensor(data); @@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() { } } +runtime::Module CreateVirtualMachine(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "The virtual machine executable has not been defined yet." + << "\n"; + *rv = CreateVirtualMachine(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cedbc4f71859..1b40f894db08 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) return vm.invoke("main", *args) else: assert isinstance(f, relay.Module), "expected expression or module" mod = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) ret = vm.invoke("main", *args) return ret @@ -573,25 +575,6 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) -def test_set_params(): - mod = relay.Module() - x = relay.var('x', shape=(10, 5)) - w = relay.var('w', shape=(6, 5)) - b = relay.var('b', shape=(6,)) - y = relay.nn.bias_add(relay.nn.dense(x, w), b) - mod["main"] = relay.Function([x, w, b], y) - vm = relay.vm.compile(mod, 'llvm') - vm.init(tvm.cpu()) - - x_np = np.random.uniform(size=(10, 5)).astype('float32') - w_np = np.random.uniform(size=(6, 5)).astype('float32') - b_np = np.random.uniform(size=(6,)).astype('float32') - ref_np = np.dot(x_np, w_np.T) + b_np - params = {'w': w_np} - vm.load_params(params) - out = vm.run(x_np, b_np) - tvm.testing.assert_allclose(out.asnumpy(), ref_np) - if __name__ == "__main__": test_id() @@ -626,4 +609,3 @@ def test_set_params(): test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() - test_set_params() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 3a317fc2d111..014648099aeb 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -22,29 +22,25 @@ from tvm import relay from tvm.relay.module import Module as rly_module from tvm.relay import vm as _vm -from tvm.relay import serializer, deserializer from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing -def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): +def create_exec(f, target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = _vm.compile(mod, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(mod, target=target, params=params) + return executable else: assert isinstance(f, relay.Module), "expected mod as relay.Module" - vm = _vm.compile(f, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(f, target=target, params=params) + return executable def veval(vm, *args, ctx=tvm.cpu()): assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" - vm.init(ctx) ret = vm.run(*args) return ret @@ -59,13 +55,11 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_vm(mod, ctx, target, params=params) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod, target, params=params) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) des_vm.init(ctx) - des_vm.load_params(params) result = des_vm.run(data) return result.asnumpy().astype(dtype) @@ -99,26 +93,25 @@ def test_serializer(): main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) mod["main"] = main - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) - glbs = ser.globals + glbs = exe.globals assert len(glbs) == 3 assert "f1" in glbs assert "f2" in glbs assert "main" in glbs - prim_ops = ser.primitive_ops + prim_ops = exe.primitive_ops assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops) - code = ser.bytecode + code = exe.bytecode assert "main 5 2 5" in code assert "f1 2 1 3" in code assert "f2 2 1 3" in code - code, lib = ser.serialize() + code, lib = exe.save() assert isinstance(code, bytearray) assert isinstance(lib, tvm.module.Module) @@ -129,24 +122,24 @@ def test_save_load(): x_data = np.random.rand(10, 10).astype('float32') # serialize. - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + vm = create_exec(f) + code, lib = vm.save() assert isinstance(code, bytearray) # save and load the code and lib file. tmp = util.tempdir() path_lib = tmp.relpath("lib.so") lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: + with open(tmp.relpath("code.ro"), "wb") as fo: fo.write(code) loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. - deser = deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) @@ -156,12 +149,12 @@ def test_const(): c = relay.const(1.0, "float32") x = relay.var('x', shape=(10, 10), dtype='float32') f = relay.Function([x], x + c) - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + exe = create_exec(f) + code, lib = exe.save() assert isinstance(code, bytearray) - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.random.rand(10, 10).astype('float32') res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) @@ -177,11 +170,11 @@ def test_if(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) # same res = veval(des_vm, x_data, x_data) @@ -213,11 +206,11 @@ def test_loop(): aarg = relay.var('accum', shape=[], dtype='int32') mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) @@ -230,11 +223,11 @@ def test_tuple(): i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) @@ -251,11 +244,11 @@ def test_adt_list(): f = relay.Function([], l321) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm) assert len(result) == 2 @@ -297,11 +290,11 @@ def test_adt_compose(): f = relay.Function([y], add_two_body) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.array(np.random.rand()).astype('float32') result = veval(des_vm, x_data) @@ -317,11 +310,11 @@ def test_closure(): clo = ff(relay.const(1.0)) main = clo(relay.const(2.0)) - vm = create_vm(main) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(main) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm) tvm.testing.assert_allclose(res.asnumpy(), 3.0) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index b5ce0ec70e51..53f573730576 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -26,9 +26,9 @@ def test_basic(): mod, params = resnet.get_workload() target = 'llvm' ctx = tvm.cpu() - vm = relay.profiler_vm.compile(mod, target) + exe = relay.profiler_vm.compile(mod, target, params=params) + vm = relay.profiler_vm.VirtualMachineProfiler(exe) vm.init(ctx) - vm.load_params(params) data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data])