Skip to content

Commit

Permalink
[LLVM][RUNTIME] Enable multi systemlib with device code (#14843)
Browse files Browse the repository at this point in the history
This PR enables combining multiple system libs
into the same static library via a system_lib_prefix attribute.
This opens the door for multiple models to be compiled separately
and then packaged into the same app via a static library.

It resolves a previous issue that prevented multiple system
libs from being linked together when they come with extra binary components
such as CUDA, due to symbol conflicts.
  • Loading branch information
tqchen authored May 13, 2023
1 parent e54bbc7 commit 05001be
Show file tree
Hide file tree
Showing 20 changed files with 322 additions and 137 deletions.
29 changes: 29 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,35 @@ constexpr const char* kConstants = "constants";
*/
constexpr const char* kExternalMods = "external_mods";

/*!
* \brief A prefix used when generating C symbols for system lib creation.
*
* This prefix guides passes that create global_symbol for internal functions
* that may have C linkage (e.g. TIR functions and some BYOC functions). It also affects
* the symbol of the fatbin blob during module export.
*
* This attribute is used to avoid symbol conflicts when we
* generate and combine multiple system libs that get linked into one.
*
* Rationale: mechanisms like BYOC rely on the common global symbol
* and each external compiler also has its own mechanism of mangling.
* As a result, we cannot rely on other mechanisms that set a global_symbol and then rename it,
* because the external compiler has already agreed on the name.
*
* system_lib_prefix provides a way to hint the passes so that generated names
* avoid conflicts from the beginning.
*
* Note that users can still directly specify global symbols that may conflict.
* It is up to the downstream toolchain to manage those external-facing functions.
*
* This does not affect non-C linkage functions. Conflicts are less of an issue for them
* because they are embedded into a fatbin blob that is stored under a different symbol.
* The system lib loader can pick the right symbol for a given prefix.
*
* Having this attribute implies system lib generation linkage.
*/
constexpr const char* kSystemLibPrefix = "system_lib_prefix";

/*!
* \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
* Generally the associated runtime::Module will indicate it requires bindings for these names,
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ constexpr const char* tvm_get_c_metadata = "get_c_metadata";
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
Expand Down
9 changes: 7 additions & 2 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ runtime::Module Build(IRModule mod, Target target);
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param c_symbol_prefix Optional symbol prefix of the blob symbol.
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m, bool system_lib);
std::string PackImportsToC(const runtime::Module& m, bool system_lib,
const std::string& c_symbol_prefix = "");

/*!
* \brief Pack imported device library to a LLVM module.
Expand All @@ -68,10 +70,13 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib);
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
* \param c_symbol_prefix Optional symbol prefix of the blob symbol.
*
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple);
const std::string& target_triple,
const std::string& c_symbol_prefix = "");
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_H_
27 changes: 21 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
files = addons if addons else []
is_system_lib = False
has_c_module = False
system_lib_prefix = None
llvm_target_string = None
global_object_format = "o"
for index, module in enumerate(modules):
Expand Down Expand Up @@ -549,6 +550,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
if module.type_key == "llvm":
is_system_lib = module.get_function("__tvm_is_system_module")()
llvm_target_string = module.get_function("_get_target_string")()
system_lib_prefix = module.get_function("__tvm_get_system_lib_prefix")()

if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
Expand All @@ -564,15 +567,21 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
raise ValueError("%s need --system-lib option" % str(fcompile))

if self.imported_modules:
pack_lib_prefix = system_lib_prefix if system_lib_prefix else ""

if enabled("llvm") and llvm_target_string:
path_obj = os.path.join(workspace_dir, f"devc.{global_object_format}")
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_string)
path_obj = os.path.join(
workspace_dir, f"{pack_lib_prefix}devc.{global_object_format}"
)
m = _ffi_api.ModulePackImportsToLLVM(
self, is_system_lib, llvm_target_string, pack_lib_prefix
)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = os.path.join(workspace_dir, "devc.c")
path_cc = os.path.join(workspace_dir, f"{pack_lib_prefix}devc.c")
with open(path_cc, "w") as f:
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib, pack_lib_prefix))
files.append(path_cc)

# The imports could contain a c module but the object format could be tar
Expand All @@ -589,7 +598,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
return fcompile(file_name, files, **kwargs)


def system_lib():
def system_lib(symbol_prefix=""):
"""Get system-wide library module singleton.
System lib is a global module that contains self register functions in startup.
Expand All @@ -602,12 +611,18 @@ def system_lib():
The system lib is intended to be linked and loaded during the entire life-cycle of the program.
If you want dynamic loading features, use dso modules instead.
Parameters
----------
symbol_prefix: Optional[str]
Optional symbol prefix used during symbol lookup. When looking up a symbol,
symbol_prefix + name is searched first, then the name without symbol_prefix.
Returns
-------
module : runtime.Module
The system-wide library module.
"""
return _ffi_api.SystemLib()
return _ffi_api.SystemLib(symbol_prefix)


def load_module(path, fmt=""):
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ ObjectPtr<Library> CreateDSOLibraryObject(std::string library_path);
* \param lib The library.
* \param wrapper Optional function used to wrap a TVMBackendPackedCFunc,
* by default WrapPackedFunc is used.
* \param symbol_prefix Optional symbol prefix that can be used to search alternative symbols.
*
* \return The corresponding loaded module.
*
* \note This function can create multiple linked modules
Expand Down
58 changes: 39 additions & 19 deletions src/runtime/system_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,8 @@
namespace tvm {
namespace runtime {

class SystemLibrary : public Library {
class SystemLibraryRegistry {
public:
SystemLibrary() = default;

void* GetSymbol(const char* name) final {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return it->second;
} else {
return nullptr;
}
}

void RegisterSymbol(const std::string& name, void* ptr) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
Expand All @@ -56,8 +44,18 @@ class SystemLibrary : public Library {
tbl_[name] = ptr;
}

static const ObjectPtr<SystemLibrary>& Global() {
static auto inst = make_object<SystemLibrary>();
void* GetSymbol(const char* name) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = tbl_.find(name);
if (it != tbl_.end()) {
return it->second;
} else {
return nullptr;
}
}

static SystemLibraryRegistry* Global() {
static SystemLibraryRegistry* inst = new SystemLibraryRegistry();
return inst;
}

Expand All @@ -68,14 +66,36 @@ class SystemLibrary : public Library {
std::unordered_map<std::string, void*> tbl_;
};

TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() {
static auto mod = CreateModuleFromLibrary(SystemLibrary::Global());
return mod;
class SystemLibrary : public Library {
public:
explicit SystemLibrary(const std::string& symbol_prefix) : symbol_prefix_(symbol_prefix) {}

void* GetSymbol(const char* name) {
if (symbol_prefix_.length() != 0) {
std::string name_with_prefix = symbol_prefix_ + name;
void* symbol = reg_->GetSymbol(name_with_prefix.c_str());
if (symbol != nullptr) return symbol;
}
return reg_->GetSymbol(name);
}

private:
SystemLibraryRegistry* reg_ = SystemLibraryRegistry::Global();
std::string symbol_prefix_;
};

// Packed function "runtime.SystemLib": returns a module created from a
// SystemLibrary. The first argument, when present, is the symbol prefix;
// with no arguments an empty prefix is used.
TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body([](TVMArgs args, TVMRetValue* rv) {
  std::string prefix;
  if (args.size() > 0) {
    prefix = args[0].operator std::string();
  }
  auto library = make_object<SystemLibrary>(prefix);
  *rv = CreateModuleFromLibrary(library);
});
} // namespace runtime
} // namespace tvm

int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
tvm::runtime::SystemLibrary::Global()->RegisterSymbol(name, ptr);
tvm::runtime::SystemLibraryRegistry::Global()->RegisterSymbol(name, ptr);
return 0;
}
31 changes: 22 additions & 9 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,15 @@ std::string SerializeModule(const runtime::Module& mod) {
}
} // namespace

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
const std::string& c_symbol_prefix) {
std::string bin = SerializeModule(mod);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;

if (c_symbol_prefix.length() != 0) {
CHECK(system_lib)
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}

// translate to C program
std::ostringstream os;
Expand All @@ -253,10 +260,10 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
os << "#ifdef __cplusplus\n"
<< "extern \"C\" {\n"
<< "#endif\n";
os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n";
os << "TVM_EXPORT extern const unsigned char " << mdev_blob_name << "[];\n";
uint64_t nbytes = bin.length();
os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "["
<< bin.length() + sizeof(nbytes) << "] = {\n ";
os << "const unsigned char " << mdev_blob_name << "[" << bin.length() + sizeof(nbytes)
<< "] = {\n ";
os << std::hex;
size_t nunit = 80 / 4;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
Expand All @@ -279,9 +286,9 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
os << "\n};\n";
if (system_lib) {
os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n";
os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)"
<< runtime::symbol::tvm_dev_mblob << ");\n";
os << "static int " << mdev_blob_name << "_reg_ = "
<< "TVMBackendRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)"
<< mdev_blob_name << ");\n";
}
os << "#ifdef __cplusplus\n"
<< "}\n"
Expand All @@ -290,7 +297,13 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
}

runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
const std::string& llvm_target_string) {
const std::string& llvm_target_string,
const std::string& c_symbol_prefix) {
if (c_symbol_prefix.length() != 0) {
CHECK(system_lib)
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}

std::string bin = SerializeModule(mod);

uint64_t nbytes = bin.length();
Expand All @@ -308,7 +321,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
ICHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string);
return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string, c_symbol_prefix);
}

TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {
#endif
auto cg = std::make_unique<CodeGenAMDGPU>();

cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false);
cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false);

cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
Expand Down
22 changes: 12 additions & 10 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,22 @@ namespace tvm {
namespace codegen {

std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target) {
LLVMTarget* llvm_target,
const std::string& c_symbol_prefix) {
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const llvm::Triple& triple = tm->getTargetTriple();
llvm::LLVMContext* ctx = llvm_target->GetContext();
std::string module_name = "devc";
std::string module_name = c_symbol_prefix + "devc";
auto module = std::make_unique<llvm::Module>(module_name, *ctx);
module->setTargetTriple(triple.str());
llvm_target->SetTargetMetadata(module.get());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;

auto* tvm_dev_mblob = new llvm::GlobalVariable(
*module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value,
runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);
mdev_blob_name, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);

// If large const data (>2GB) is saved to default .rodata section
// then linking it to shared library will fail - relocation truncated to fit: R_X86_64_PC32.
Expand Down Expand Up @@ -106,9 +109,9 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
auto int8_ptr_ty = int8_ty->getPointerTo(0);

llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg = new llvm::GlobalVariable(
*module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero,
std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
auto* tvm_dev_mblob_reg =
new llvm::GlobalVariable(*module, int32_ty, false, llvm::GlobalValue::InternalLinkage,
constant_zero, mdev_blob_name + "_reg_");
auto tvm_dev_mblob_reg_alignment =
#if TVM_LLVM_VERSION >= 110
module->getDataLayout().getABITypeAlign(int32_ty);
Expand All @@ -121,13 +124,12 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment);
#endif

auto* tvm_dev_mblob_string_ty =
llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1);
auto* tvm_dev_mblob_string_ty = llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1);
auto* tvm_dev_mblob_string_value =
llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true);
llvm::ConstantDataArray::getString(*ctx, mdev_blob_name, true);
auto* tvm_dev_mblob_string = new llvm::GlobalVariable(
*module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage,
tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str");
tvm_dev_mblob_string_value, mdev_blob_name + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
#else
Expand Down
4 changes: 3 additions & 1 deletion src/target/llvm/codegen_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ class LLVMTarget;
* \param data Blob data
* \param system_lib Whether expose as system library.
* \param llvm_target The LLVM target
* \param c_symbol_prefix The C symbol prefix of the blob.
*
* \return LLVM module and LLVM context
*/
std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target);
LLVMTarget* llvm_target,
const std::string& c_symbol_prefix = "");

} // namespace codegen
} // namespace tvm
Expand Down
Loading

0 comments on commit 05001be

Please sign in to comment.