Skip to content

Commit

Permalink
[LLVM][RUNTIME] Enable multi systemlib with device code
Browse files Browse the repository at this point in the history
This PR enables combining multiple system libs
into the same static library via a system_lib_prefix attribute.
This opens the door for multiple models to be compiled separately
and then packaged into the same app via a static library.

It resolves a previous issue that prevented multiple system
libs from being linked together when they come with an extra binary component
such as CUDA, due to symbol conflicts.
  • Loading branch information
tqchen committed May 13, 2023
1 parent ae9209b commit be88b90
Show file tree
Hide file tree
Showing 19 changed files with 311 additions and 130 deletions.
29 changes: 29 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,35 @@ constexpr const char* kConstants = "constants";
*/
constexpr const char* kExternalMods = "external_mods";

/*!
* \brief A prefix for generating C symbols during system lib creation.
*
* This prefix guides passes that create global_symbol for internal functions
* that may have C linkage (e.g. TIR functions and some BYOC functions). It also affects
* the symbol of the fat bin blob during module export.
*
* This attribute is used to avoid symbol conflict when we
* generate and combine multiple system libs that get linked into one.
*
* Rationale: mechanisms like BYOC rely on the common global symbol
* and each external compiler also has its own mechanism of mangling.
* As a result, we cannot rely on other mechanisms on setting a global_symbol and then renaming,
* because the external compiler already agreed on the name.
*
* system_lib_prefix provides a way to hint to the passes so that the names
* they generate avoid conflicts from the beginning.
*
* Note that users can still directly specify global symbols that may conflict.
* It is up to the downstream toolchain to manage those external-facing functions.
*
* This does not affect non-C linkage functions; they are less of an issue because
* they will be embedded into fatbins that are stored under different symbols.
* The system lib loader can pick the right symbol for a given prefix.
*
* Having this attribute implies system lib generation linkage.
*/
constexpr const char* kSystemLibPrefix = "system_lib_prefix";

/*!
* \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
* Generally the associated runtime::Module will indicate it requires bindings for these names,
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ constexpr const char* tvm_get_c_metadata = "get_c_metadata";
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
Expand Down
9 changes: 7 additions & 2 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ runtime::Module Build(IRModule mod, Target target);
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param c_symbol_prefix Optional symbol prefix of the blob symbol.
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m, bool system_lib);
std::string PackImportsToC(const runtime::Module& m, bool system_lib,
const std::string& c_symbol_prefix = "");

/*!
* \brief Pack imported device library to a LLVM module.
Expand All @@ -68,10 +70,13 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib);
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
* \param c_symbol_prefix Optional symbol prefix of the blob symbol.
*
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple);
const std::string& target_triple,
const std::string& c_symbol_prefix = "");
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_H_
27 changes: 21 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
files = addons if addons else []
is_system_lib = False
has_c_module = False
system_lib_prefix = None
llvm_target_string = None
global_object_format = "o"
for index, module in enumerate(modules):
Expand Down Expand Up @@ -549,6 +550,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
if module.type_key == "llvm":
is_system_lib = module.get_function("__tvm_is_system_module")()
llvm_target_string = module.get_function("_get_target_string")()
system_lib_prefix = module.get_function("__tvm_get_system_lib_prefix")()

if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
Expand All @@ -564,15 +567,21 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
raise ValueError("%s need --system-lib option" % str(fcompile))

if self.imported_modules:
pack_lib_prefix = system_lib_prefix if system_lib_prefix else ""

if enabled("llvm") and llvm_target_string:
path_obj = os.path.join(workspace_dir, f"devc.{global_object_format}")
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_string)
path_obj = os.path.join(
workspace_dir, f"{pack_lib_prefix}devc.{global_object_format}"
)
m = _ffi_api.ModulePackImportsToLLVM(
self, is_system_lib, llvm_target_string, pack_lib_prefix
)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = os.path.join(workspace_dir, "devc.c")
path_cc = os.path.join(workspace_dir, f"{pack_lib_prefix}devc.c")
with open(path_cc, "w") as f:
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib, pack_lib_prefix))
files.append(path_cc)

# The imports could contain a c module but the object format could be tar
Expand All @@ -589,7 +598,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
return fcompile(file_name, files, **kwargs)


def system_lib():
def system_lib(symbol_prefix=""):
"""Get system-wide library module singleton.
System lib is a global module that contains self register functions in startup.
Expand All @@ -602,12 +611,18 @@ def system_lib():
The system lib is intended to be linked and loaded during the entire life-cycle of the program.
If you want dynamic loading features, use dso modules instead.
Parameters
----------
symbol_prefix: Optional[str]
Optional symbol prefix that can be used for search. When we look up a symbol,
symbol_prefix + name will be searched first, then the name without symbol_prefix.
Returns
-------
module : runtime.Module
The system-wide library module.
"""
return _ffi_api.SystemLib()
return _ffi_api.SystemLib(symbol_prefix)


def load_module(path, fmt=""):
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ ObjectPtr<Library> CreateDSOLibraryObject(std::string library_path);
* \param lib The library.
* \param wrapper Optional function used to wrap a TVMBackendPackedCFunc,
* by default WrapPackedFunc is used.
* \param symbol_prefix Optional symbol prefix that can be used to search alternative symbols.
*
* \return The corresponding loaded module.
*
* \note This function can create multiple linked modules
Expand Down
58 changes: 39 additions & 19 deletions src/runtime/system_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,8 @@
namespace tvm {
namespace runtime {

class SystemLibrary : public Library {
class SystemLibraryRegistry {
public:
SystemLibrary() = default;

void* GetSymbol(const char* name) final {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return it->second;
} else {
return nullptr;
}
}

void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
Expand All @@ -56,8 +44,18 @@ class SystemLibrary : public Library {
tbl_[name] = ptr;
}

static const ObjectPtr<SystemLibrary>& Global() {
static auto inst = make_object<SystemLibrary>();
/*!
 * \brief Find a registered symbol by its exact name.
 *
 * mutex_ is held for the duration of the table lookup.
 *
 * \param name The symbol name to search for.
 * \return The pointer stored under name, or nullptr when no entry exists.
 */
void* GetSymbol(const char* name) {
  std::lock_guard<std::mutex> guard(mutex_);
  const auto entry = tbl_.find(name);
  return entry == tbl_.end() ? nullptr : entry->second;
}

/*!
 * \brief Get the registry instance shared by all callers.
 * \return Pointer to a function-local static instance that is allocated on the first call.
 */
static SystemLibraryRegistry* Global() {
  // Allocated once with new and kept in a function-local static; every call returns the same pointer.
  static SystemLibraryRegistry* registry = new SystemLibraryRegistry();
  return registry;
}

Expand All @@ -68,14 +66,36 @@ class SystemLibrary : public Library {
std::unordered_map<std::string, void*> tbl_;
};

TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() {
static auto mod = CreateModuleFromLibrary(SystemLibrary::Global());
return mod;
class SystemLibrary : public Library {
public:
explicit SystemLibrary(const std::string& symbol_prefix) : symbol_prefix_(symbol_prefix) {}

void* GetSymbol(const char* name) {
if (symbol_prefix_.length() != 0) {
std::string name_with_prefix = symbol_prefix_ + name;
void* symbol = reg_->GetSymbol(name_with_prefix.c_str());
if (symbol != nullptr) return symbol;
}
return reg_->GetSymbol(name);
}

private:
SystemLibraryRegistry* reg_ = SystemLibraryRegistry::Global();
std::string symbol_prefix_;
};

// Packed function "runtime.SystemLib": creates a module backed by a SystemLibrary.
// The first argument, when present, is used as the symbol prefix; with no
// arguments the prefix is the empty string.
TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body([](TVMArgs args, TVMRetValue* rv) {
  const std::string prefix = (args.size() == 0) ? std::string("") : args[0].operator std::string();
  auto module = CreateModuleFromLibrary(make_object<SystemLibrary>(prefix));
  *rv = module;
});
} // namespace runtime
} // namespace tvm

int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibrary::Global()->RegisterSymbol(name, ptr);
tvm::runtime::SystemLibraryRegistry::Global()->RegisterSymbol(name, ptr);
return 0;
}
31 changes: 22 additions & 9 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,15 @@ std::string SerializeModule(const runtime::Module& mod) {
}
} // namespace

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
const std::string& c_symbol_prefix) {
std::string bin = SerializeModule(mod);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;

if (c_symbol_prefix.length() != 0) {
CHECK(system_lib)
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}

// translate to C program
std::ostringstream os;
Expand All @@ -253,10 +260,10 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
os << "#ifdef __cplusplus\n"
<< "extern \"C\" {\n"
<< "#endif\n";
os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n";
os << "TVM_EXPORT extern const unsigned char " << mdev_blob_name << "[];\n";
uint64_t nbytes = bin.length();
os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "["
<< bin.length() + sizeof(nbytes) << "] = {\n ";
os << "const unsigned char " << mdev_blob_name << "[" << bin.length() + sizeof(nbytes)
<< "] = {\n ";
os << std::hex;
size_t nunit = 80 / 4;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
Expand All @@ -279,9 +286,9 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
os << "\n};\n";
if (system_lib) {
os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n";
os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)"
<< runtime::symbol::tvm_dev_mblob << ");\n";
os << "static int " << mdev_blob_name << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)"
<< mdev_blob_name << ");\n";
}
os << "#ifdef __cplusplus\n"
<< "}\n"
Expand All @@ -290,7 +297,13 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
}

runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
const std::string& llvm_target_string) {
const std::string& llvm_target_string,
const std::string& c_symbol_prefix) {
if (c_symbol_prefix.length() != 0) {
CHECK(system_lib)
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}

std::string bin = SerializeModule(mod);

uint64_t nbytes = bin.length();
Expand All @@ -308,7 +321,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
ICHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string);
return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string, c_symbol_prefix);
}

TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {
#endif
auto cg = std::make_unique<CodeGenAMDGPU>();

cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false);
cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false);

cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
Expand Down
22 changes: 12 additions & 10 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,22 @@ namespace tvm {
namespace codegen {

std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target) {
LLVMTarget* llvm_target,
const std::string& c_symbol_prefix) {
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const llvm::Triple& triple = tm->getTargetTriple();
llvm::LLVMContext* ctx = llvm_target->GetContext();
std::string module_name = "devc";
std::string module_name = c_symbol_prefix + "devc";
auto module = std::make_unique<llvm::Module>(module_name, *ctx);
module->setTargetTriple(triple.str());
llvm_target->SetTargetMetadata(module.get());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;

auto* tvm_dev_mblob = new llvm::GlobalVariable(
*module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value,
runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);
mdev_blob_name, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);

// If large const data (>2GB) is saved to default .rodata section
// then linking it to shared library will fail - relocation truncated to fit: R_X86_64_PC32.
Expand Down Expand Up @@ -106,9 +109,9 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
auto int8_ptr_ty = int8_ty->getPointerTo(0);

llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg = new llvm::GlobalVariable(
*module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero,
std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
auto* tvm_dev_mblob_reg =
new llvm::GlobalVariable(*module, int32_ty, false, llvm::GlobalValue::InternalLinkage,
constant_zero, mdev_blob_name + "_reg_");
auto tvm_dev_mblob_reg_alignment =
#if TVM_LLVM_VERSION >= 110
module->getDataLayout().getABITypeAlign(int32_ty);
Expand All @@ -121,13 +124,12 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment);
#endif

auto* tvm_dev_mblob_string_ty =
llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1);
auto* tvm_dev_mblob_string_ty = llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1);
auto* tvm_dev_mblob_string_value =
llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true);
llvm::ConstantDataArray::getString(*ctx, mdev_blob_name, true);
auto* tvm_dev_mblob_string = new llvm::GlobalVariable(
*module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage,
tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str");
tvm_dev_mblob_string_value, mdev_blob_name + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
#else
Expand Down
4 changes: 3 additions & 1 deletion src/target/llvm/codegen_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ class LLVMTarget;
* \param data Blob data
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
* \param c_symbol_prefix The C symbol prefix of the blob.
*
* \return LLVM module and LLVM context
*/
std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target);
LLVMTarget* llvm_target,
const std::string& c_symbol_prefix = "");

} // namespace codegen
} // namespace tvm
Expand Down
Loading

0 comments on commit be88b90

Please sign in to comment.