Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using the VM across the RPC boundary. #7746

Merged
merged 10 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class Executable : public ModuleNode {
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

/*!
* \brief Save the entire executable to a binary stream.
* \param stream The binary stream to save to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

void SaveToFile(const std::string& path, const std::string& format) final;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
Expand Down Expand Up @@ -125,9 +133,17 @@ class Executable : public ModuleNode {
* \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
* \return The runtime module that contains the hardware dependent code.
*/
runtime::Module GetLib() const { return lib; }
runtime::Module GetLib() const { return this->imports_[0]; }
jroesch marked this conversation as resolved.
Show resolved Hide resolved

void SetLib(const runtime::Module& lib) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ICHECK(lib.defined()) << "library can not be null";

ICHECK_EQ(this->imports().size(), 0) << "can only import the library once";

this->Import(lib);
}

/*!
* \brief Get the arity of the VM Fucntion.
Expand All @@ -148,9 +164,6 @@ class Executable : public ModuleNode {

const char* type_key() const final { return "VMExecutable"; }

/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
Expand Down
19 changes: 13 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,24 +269,31 @@ def _collect_dso_modules(self):
return self._collect_from_import_tree(is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
"""Export the module and its imported device code one library.
"""
Export the module and all imported modules into a single device library.

This function only works on host llvm modules.
It will pack all the imported modules
This function only works on hos LLVM modules, other runtime::Module
jroesch marked this conversation as resolved.
Show resolved Hide resolved
subclasses will work with this API but they must support implement
the save and load mechanisms of modules completely including saving
from streams and files. This will pack your non-shared library module
into a single shared library which can later be loaded by TVM.

Parameters
----------
file_name : str
The name of the shared library.

fcompile : function(target, file_list, kwargs), optional
Compilation function to use create dynamic library.
The compilation function to use create the final library object during
export. For example this is used to link together all produced artifacts
jroesch marked this conversation as resolved.
Show resolved Hide resolved
into a final dynamic library.
This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".

workspace_dir : str, optional
the path to a directory used to create intermediary
artifacts for the process exporting of the library.
The path of the directory used to create the intermediate
artifacts when exporting the module.
If this is not provided a temporary dir will be created.

kwargs : dict, optional
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

import tvm
from tvm.runtime import Module
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from .object import Object
Expand Down Expand Up @@ -299,12 +300,16 @@ class VirtualMachine(object):
POOLED_ALLOCATOR = 2

def __init__(self, exe, device, memory_cfg=None):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(exe, Executable):
if not isinstance(exe, Executable) and not isinstance(exe, Module):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
"exe is expected to be the type of Executable, "
+ "but received {}".format(type(exe))
)
self.module = _ffi_api._VirtualMachine(exe.module)

if not isinstance(exe, Executable):
exe = Executable(exe)

self.module = exe.mod["vm_load_executable"]()
self._exec = exe
self._init = self.module["init"]
self._invoke = self.module["invoke"]
Expand Down
8 changes: 5 additions & 3 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() {

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module lib;
if (funcs.size() > 0) {
Map<String, IRModule> build_funcs;
for (const auto& i : funcs) {
build_funcs.Set(i.first, i.second);
}
exec_->lib = tvm::build(build_funcs, target_host_);
lib = tvm::build(build_funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
exec_->SetLib(lib);
jroesch marked this conversation as resolved.
Show resolved Hide resolved
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Expand Down
42 changes: 23 additions & 19 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
#undef TVM_INIT_CONTEXT_FUNC
}

Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + type_key;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
ICHECK(f != nullptr) << "Binary was created using " << type_key
jroesch marked this conversation as resolved.
Show resolved Hide resolved
<< " but a loader of that name is not registered. Available loaders are "
<< loaders << ". Perhaps you need to recompile with this runtime enabled.";
}

return (*f)(static_cast<void*>(stream));
}

/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
Expand Down Expand Up @@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
} else {
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + tkey;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
ICHECK(f != nullptr)
<< "Binary was created using " << tkey
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}
Module m = (*f)(static_cast<void*>(stream));
auto m = LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@
#include <tvm/runtime/module.h>

#include <functional>
#include <string>

namespace tvm {
namespace runtime {

/*! \brief Load a module with the given type key directly from the stream.
* This function wraps the registry mechanism used to store type based deserializers
* for each runtime::Module sub-class.
*
* \param type_key The type key of the serialized module.
* \param stream A pointer to the stream containing the serialized module.
* \return module The deserialized module.
*/
Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);

/*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
Expand Down
59 changes: 58 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../library_module.h"
#include "serialize_utils.h"

namespace tvm {
Expand Down Expand Up @@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else if (name == "vm_load_executable") {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(this);
*rv = Module(vm);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
Expand Down Expand Up @@ -476,8 +484,19 @@ void LoadHeader(dmlc::Stream* strm) {
}

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::cout << "code: " << code.size() << std::endl;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
auto exec = make_object<Executable>();
exec->lib = lib;

// Support null-initialization of lib, to enable initialization during
// deserialization before we have we have deserialized the imports.
if (lib.defined()) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ICHECK_EQ(exec->imports_.size(), 0)
jroesch marked this conversation as resolved.
Show resolved Hide resolved
<< "A VMExecutable should never have more than one import inside an the executable, \n"
<< "the first import should *always* be the library containing"
<< "the platform specific kernel code";
exec->Import(lib);
jroesch marked this conversation as resolved.
Show resolved Hide resolved
}

exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);

Expand Down Expand Up @@ -765,6 +784,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}

void Executable::SaveToBinary(dmlc::Stream* stream) {
auto code_bytes = this->Save();
std::string code(code_bytes.data, code_bytes.size);
stream->Write(code);

ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization";
}

Module ExecutableLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string code;
stream->Read(&code);
auto exec = Executable::Load(code, Module());
return exec;
}

void Executable::SaveToFile(const std::string& path, const std::string& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
SaveToBinary(strm);
SaveBinaryToFile(path, data);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);

// Load module from module.
Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
return exec;
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);

TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
ICHECK(exec) << "The executable is not created yet.";
exec_ = exec;

runtime::Module lib = exec_->lib;
// Get the list of packed functions.
runtime::Module lib = exec_->GetLib();

ICHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions"
<< "\n";
<< "If the executable has declared primitive functions, the"
<< "generated kernel library must non-be null.";

for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
Expand Down
30 changes: 29 additions & 1 deletion tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import tvm
from tvm import runtime
from tvm import relay
from tvm import relay, IRModule
from tvm.relay.backend import vm
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
from tvm.relay.loops import while_loop
from tvm.relay import testing
from tvm.contrib import utils
from tvm import rpc
import tvm.testing


Expand Down Expand Up @@ -799,5 +802,30 @@ def test_constant_shape_with_external_codegen():
assert "shape_func" in opt_mod.astext(False)


def test_vm_rpc():
target = "llvm"
target_host = "llvm"

x = relay.var("x", shape=(10, 1))
f = relay.Function([x], x + x)
mod = IRModule.from_expr(f)
vm_exec = vm.compile(mod, target=target, target_host=target_host)

temp = utils.tempdir()
jroesch marked this conversation as resolved.
Show resolved Hide resolved
path = temp.relpath("vm_library.so")
vm_exec.mod.export_library(path)

remote = rpc.LocalSession()
remote.upload(path)
rexec = remote.load_module("vm_library.so")

ctx = remote.cpu()
vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
np_input = np.random.uniform(size=(10, 1)).astype("float32")
input_tensor = tvm.nd.array(np_input, ctx)
out = vm_factory.invoke("main", [input_tensor])
np.testing.assert_allclose(out.asnumpy(), np_input + np_input)


if __name__ == "__main__":
pytest.main([__file__])