Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created CSourceMetaData module for model metadata #7002

Merged
merged 18 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions apps/bundle_deploy/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def build_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)

lib.save(os.path.join(build_dir, file_format_str.format(name="model", ext="o")))
with open(
os.path.join(build_dir, file_format_str.format(name="graph", ext="json")), "w"
Expand Down Expand Up @@ -85,7 +84,6 @@ def build_test_module(opts):
build_dir = os.path.abspath(opts.out_dir)
if not os.path.isdir(build_dir):
os.makedirs(build_dir)

lib.save(os.path.join(build_dir, file_format_str.format(name="test_model", ext="o")))
with open(
os.path.join(build_dir, file_format_str.format(name="test_graph", ext="json")), "w"
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,16 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
for mdev in device_modules:
if mdev:
rt_mod_host.import_module(mdev)

if not isinstance(target_host, Target):
target_host = Target(target_host)
if (
"system-lib" in target_host.attrs
and target_host.attrs["system-lib"].value == 1
and target_host.kind.name == "c"
):
create_csource_metadata_module = tvm._ffi.get_global_func(
"runtime.CreateCSourceMetadataModule"
)
return create_csource_metadata_module([rt_mod_host], target_host)
return rt_mod_host
14 changes: 10 additions & 4 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def path(self):
# void* arg0 = (((TVMValue*)args)[0].v_handle);
# int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
_CRT_GENERATED_LIB_OPTIONS["cflags"].append("-Wno-unused-variable")
_CRT_GENERATED_LIB_OPTIONS["ccflags"].append("-Wno-unused-variable")


# Many TVM-intrinsic operators (i.e. expf, in particular)
Expand Down Expand Up @@ -159,9 +160,6 @@ def build_static_runtime(
mod_build_dir = workspace.relpath(os.path.join("build", "module"))
os.makedirs(mod_build_dir)
mod_src_dir = workspace.relpath(os.path.join("src", "module"))
os.makedirs(mod_src_dir)
mod_src_path = os.path.join(mod_src_dir, "module.c")
module.save(mod_src_path, "cc")

libs = []
for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS:
Expand All @@ -181,7 +179,15 @@ def build_static_runtime(

libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts))

libs.append(compiler.library(mod_build_dir, [mod_src_path], generated_lib_opts))
mod_src_dir = workspace.relpath(os.path.join("src", "module"))
os.makedirs(mod_src_dir)
libs.append(
module.export_library(
mod_build_dir,
workspace_dir=mod_src_dir,
fcompile=lambda bdir, srcs, **kwargs: compiler.library(bdir, srcs, generated_lib_opts),
)
)

runtime_build_dir = workspace.relpath(f"build/runtime")
os.makedirs(runtime_build_dir)
Expand Down
40 changes: 30 additions & 10 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-import, import-outside-toplevel
# pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements
"""Runtime Module namespace."""
import os
import ctypes
Expand Down Expand Up @@ -252,7 +252,7 @@ def _collect_dso_modules(self):
def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"

def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""Export the module and its imported device code one library.

This function only works on host llvm modules.
Expand All @@ -268,8 +268,19 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".

workspace_dir : str, optional
the path to a directory used to create intermediary
artifacts for the process exporting of the library.
If this is not provided a temporary dir will be created.

kwargs : dict, optional
Additional arguments passed to fcompile

Returns
-------
result of fcompile() : unknown, optional
If the compilation function returns an artifact it would be returned via
export_library, if any.
"""
# NOTE: this function depends on contrib library features
# which are only available in when TVM function is available.
Expand All @@ -292,22 +303,28 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
return

modules = self._collect_dso_modules()
temp = _utils.tempdir()
if workspace_dir is None:
temp = _utils.tempdir()
workspace_dir = temp.temp_dir
files = addons if addons else []
is_system_lib = False
has_c_module = False
llvm_target_triple = None
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
if module.type_key == "c":
object_format = "c"
has_c_module = True
else:
object_format = fcompile.object_format
else:
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
object_format = "cc"
object_format = "c"
has_c_module = True
path_obj = temp.relpath("lib" + str(index) + "." + object_format)
path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
module.save(path_obj)
files.append(path_obj)
is_system_lib = (
Expand All @@ -330,25 +347,28 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs):

if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc." + object_format)
path_obj = os.path.join(workspace_dir, f"devc.{object_format}")
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc")
path_cc = os.path.join(workspace_dir, "devc.c")
with open(path_cc, "w") as f:
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
files.append(path_cc)

if has_c_module:
# The imports could contain a c module but the object format could be tar
# Thus, it would not recognize the following include paths as options
# which are there assuming a c compiler is the fcompile.
if has_c_module and not file_name.endswith(".tar"):
options = []
if "options" in kwargs:
opts = kwargs["options"]
options = opts if isinstance(opts, (list, tuple)) else [opts]
opts = options + ["-I" + path for path in find_include_path()]
kwargs.update({"options": opts})

fcompile(file_name, files, **kwargs)
return fcompile(file_name, files, **kwargs)


def system_lib():
Expand Down
10 changes: 3 additions & 7 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,18 +510,14 @@ class RelayBuildModule : public runtime::ModuleNode {
// If we cannot decide the target is LLVM, we create an empty CSourceModule.
// The code content is initialized with ";" to prevent complaining
// from CSourceModuleNode::SaveToFile.
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
}
} else {
ret_.mod = tvm::build(lowered_funcs, target_host_);
}

Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
// TODO(zhiics) We should be able to completely switch to MetadataModule no
// matter whether there are external modules or not.
if (!ext_mods.empty()) {
ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
}
auto ext_mods = graph_codegen_->GetExternalModules();
ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost());
}

private:
Expand Down
26 changes: 10 additions & 16 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,20 +215,19 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code

class CSourceCodegen : public CSourceModuleCodegenBase {
public:
std::pair<std::string, Array<String>> GenCFunc(const Function& func) {
std::tuple<Array<String>, String, String> GenCFunc(const Function& func) {
ICHECK(func.defined()) << "Input error: expect a Relay function.";

// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);

CodegenC builder(sid);
CodegenC builder(GetExtSymbol(func));
auto out = builder.VisitExpr(func->body);
code_stream_ << builder.JIT(out);

return {sid, builder.const_vars_};
return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out));
}

runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCFunc(Downcast<Function>(ref));
Array<String> variables = std::get<0>(res);
String func_name = std::get<1>(res);

// Create headers
code_stream_ << "#include <cstring>\n";
code_stream_ << "#include <vector>\n";
Expand Down Expand Up @@ -259,18 +258,13 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
)op_macro";

code_stream_ << operator_macro << "\n\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCFunc(Downcast<Function>(ref));
code_stream_ << std::get<2>(res);
std::string code = code_stream_.str();

String sym = std::get<0>(res);
Array<String> variables = std::get<1>(res);

// Create a CSource module
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code, "c", sym, variables);
return (*pf)(code, "c", Array<String>{func_name}, variables);
}

private:
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
// Create a CSource module
const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code, "c", sym, variables);
// TODO(@manupa-arm): pass the function names to enable system-lib creation
return (*pf)(code, "c", Array<String>{sym}, variables);
}

private:
Expand Down
6 changes: 2 additions & 4 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,11 +1146,9 @@ void VMCompiler::Codegen() {
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
exec_->lib = codegen::CSourceModuleCreate(";", "");
}
if (!ext_mods.empty()) {
exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods);
exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Expand Down
2 changes: 1 addition & 1 deletion src/target/func_registry_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace tvm {
namespace target {

std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names) {
std::string GenerateFuncRegistryNames(const Array<String>& function_names) {
std::stringstream ss;
ss << (unsigned char)(function_names.size());
for (auto f : function_names) {
Expand Down
7 changes: 6 additions & 1 deletion src/target/func_registry_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@
#ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_
#define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_

#include <tvm/runtime/container.h>

#include <string>
#include <vector>

using tvm::runtime::Array;
using tvm::runtime::String;

namespace tvm {
namespace target {

std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names);
std::string GenerateFuncRegistryNames(const Array<String>& function_names);

} // namespace target
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,10 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
void CodeGenCPU::AddStartupFunction() {
if (registry_functions_.size() != 0) {
ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
std::vector<std::string> symbols;
Array<String> symbols;
std::vector<llvm::Constant*> funcs;
for (auto sym : registry_functions_) {
symbols.emplace_back(sym.first);
symbols.push_back(sym.first);
funcs.emplace_back(llvm::ConstantExpr::getBitCast(
sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo()));
}
Expand Down
14 changes: 12 additions & 2 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ class LLVMModuleNode final : public runtime::ModuleNode {
if (name == "__tvm_is_system_module") {
bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr);
return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; });
} else if (name == "get_func_names") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; });
} else if (name == "get_symbol") {
return PackedFunc(nullptr);
} else if (name == "get_const_vars") {
return PackedFunc(nullptr);
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
// getTargetTriple() doesn't include other flags besides the triple. Add back flags which are
Expand Down Expand Up @@ -218,9 +225,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey();
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
function_names_.push_back(global_symbol.value());
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
entry_func = global_symbol.value();
}
funcs.push_back(f);
Expand Down Expand Up @@ -377,6 +385,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<llvm::Module> module_;
// the context.
std::shared_ptr<llvm::LLVMContext> ctx_;
/* \brief names of the functions declared in this module */
Array<String> function_names_;
};

TVM_REGISTER_GLOBAL("target.build.llvm")
Expand Down
Loading