Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions docs/arch/introduction_to_module_serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ before. The ``import_tree_logic`` is just to write ``import_tree_row_ptr_``
and ``import_tree_child_indices_`` into stream.

After this step, we will pack it into a symbol
``runtime::symbol::tvm_dev_mblob`` that can be recovered in the dynamic
``runtime::symbol::tvm_ffi_library_bin`` that can be recovered in the dynamic
library.

Now, we complete the serialization part. As you have seen, we could
Expand All @@ -152,18 +152,19 @@ according to the function logic, we will call ``module.loadfile_so`` in
.. code:: c++

// Load the imported modules
const char* dev_mblob = reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
const char* library_bin = reinterpret_cast<const char*>(
lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin));
Module root_mod;
if (dev_mblob != nullptr) {
root_mod = ProcessModuleBlob(dev_mblob, lib);
if (library_bin != nullptr) {
root_mod = ProcessLibraryBin(library_bin, lib);
} else {
// Only have one single DSO Module
root_mod = Module(n);
// Only have one single DSO Module
root_mod = Module(n);
}

As said before, we will pack the blob into the symbol
``runtime::symbol::tvm_dev_mblob``. During deserialization part, we will
inspect it. If we have ``runtime::symbol::tvm_dev_mblob``, we will call ``ProcessModuleBlob``,
``runtime::symbol::tvm_ffi_library_bin``. During deserialization part, we will
inspect it. If we have ``runtime::symbol::tvm_ffi_library_bin``, we will call ``ProcessLibraryBin``,
whose logic like this:

.. code:: c++
Expand Down
4 changes: 2 additions & 2 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ typedef struct {
*
* Safe call explicitly catches exception on function boundary.
*
* \param self The function handle
* \param handle The function handle
* \param num_args Number of input arguments
* \param args The input arguments to the call.
* \param result Store output result.
Expand All @@ -278,7 +278,7 @@ typedef struct {
* \sa TVMFFIErrorSetRaised
* \sa TVMFFIErrorSetRaisedByCStr
*/
typedef int (*TVMFFISafeCallType)(void* self, const TVMFFIAny* args, int32_t num_args,
typedef int (*TVMFFISafeCallType)(void* handle, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result);

/*!
Expand Down
16 changes: 4 additions & 12 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,10 @@ inline ffi::Function Module::GetFunction(const String& name, bool query_imports)

/*! \brief namespace for constant symbols */
namespace symbol {
/*! \brief A ffi::Function that retrieves exported metadata. */
constexpr const char* tvm_get_c_metadata = "get_c_metadata";
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Global variable to store context pointer for a library module. */
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
/*! \brief Global variable to store binary data alongside a library module. */
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
Expand All @@ -300,12 +298,6 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Prefix for parameter symbols emitted into the main program. */
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A ffi::Function that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */
constexpr const char* tvm_entrypoint_suffix = "run";
} // namespace symbol

// implementations of inline functions.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ cdef extern from "tvm/ffi/c_api.h":
void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback)

ctypedef int (*TVMFFISafeCallType)(
void* ctx, const TVMFFIAny* args, int32_t num_args,
void* handle, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result) noexcept

int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil
Expand Down
13 changes: 7 additions & 6 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
* \param root_module the output root module
* \param dso_ctx_addr the output dso module
*/
void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib,
void ProcessLibraryBin(const char* mblob, ObjectPtr<Library> lib,
FFIFunctionWrapper packed_func_wrapper, runtime::Module* root_module,
runtime::ModuleNode** dso_ctx_addr = nullptr) {
ICHECK(mblob != nullptr);
Expand Down Expand Up @@ -184,21 +184,22 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib, FFIFunctionWrapper packed
InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
const char* library_bin =
reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin));

Module root_mod;
runtime::ModuleNode* dso_ctx_addr = nullptr;
if (dev_mblob != nullptr) {
ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr);
if (library_bin != nullptr) {
ProcessLibraryBin(library_bin, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr);
} else {
// Only have one single DSO Module
root_mod = Module(n);
dso_ctx_addr = root_mod.operator->();
}

// allow lookup of symbol from root (so all symbols are visible).
if (auto* ctx_addr = reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
if (auto* ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_ffi_library_ctx))) {
*ctx_addr = dso_ctx_addr;
}

Expand Down
4 changes: 2 additions & 2 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}

std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin;
std::string blob = PackImportsToBytes(mod);

// translate to C program
Expand Down Expand Up @@ -365,7 +365,7 @@ TVM_FFI_REGISTER_GLOBAL("target.Build").set_body_typed(Build);

// Export a few auxiliary function to the runtime namespace.
TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string {
return runtime::symbol::tvm_dev_mblob;
return runtime::symbol::tvm_ffi_library_bin;
});

TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray")
Expand Down
43 changes: 22 additions & 21 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
llvm_target->SetTargetMetadata(module.get());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin;

auto* tvm_dev_mblob = new llvm::GlobalVariable(
auto* tvm_ffi_library_bin = new llvm::GlobalVariable(
*module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value,
mdev_blob_name, nullptr, llvm::GlobalVariable::NotThreadLocal, 0);

Expand All @@ -88,17 +88,17 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
const size_t large_data_threshold = 1 << 30;
if (data.size() > large_data_threshold && triple.getArch() == llvm::Triple::x86_64 &&
triple.isOSBinFormatELF()) {
tvm_dev_mblob->setSection(".lrodata");
tvm_ffi_library_bin->setSection(".lrodata");
}

#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob->setAlignment(llvm::Align(1));
tvm_ffi_library_bin->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob->setAlignment(1);
tvm_ffi_library_bin->setAlignment(1);
#endif

if (triple.isOSWindows()) {
tvm_dev_mblob->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
tvm_ffi_library_bin->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
}

if (system_lib) {
Expand All @@ -109,31 +109,32 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
auto int8_ptr_ty = llvmGetPointerTo(int8_ty, 0);

llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg =
auto* tvm_ffi_library_bin_reg =
new llvm::GlobalVariable(*module, int32_ty, false, llvm::GlobalValue::InternalLinkage,
constant_zero, mdev_blob_name + "_reg_");
auto tvm_dev_mblob_reg_alignment =
auto tvm_ffi_library_bin_reg_alignment =
#if TVM_LLVM_VERSION >= 110
module->getDataLayout().getABITypeAlign(int32_ty);
#else
module->getDataLayout().getABITypeAlignment(int32_ty);
#endif
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment));
tvm_ffi_library_bin_reg->setAlignment(llvm::Align(tvm_ffi_library_bin_reg_alignment));
#else
tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment);
tvm_ffi_library_bin_reg->setAlignment(tvm_ffi_library_bin_reg_alignment);
#endif

auto* tvm_dev_mblob_string_ty = llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1);
auto* tvm_dev_mblob_string_value =
auto* tvm_ffi_library_bin_string_ty =
llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1);
auto* tvm_ffi_library_bin_string_value =
llvm::ConstantDataArray::getString(*ctx, mdev_blob_name, true);
auto* tvm_dev_mblob_string = new llvm::GlobalVariable(
*module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage,
tvm_dev_mblob_string_value, mdev_blob_name + ".str");
auto* tvm_ffi_library_bin_string = new llvm::GlobalVariable(
*module, tvm_ffi_library_bin_string_ty, true, llvm::GlobalValue::PrivateLinkage,
tvm_ffi_library_bin_string_value, mdev_blob_name + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
tvm_ffi_library_bin_string->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob_string->setAlignment(1);
tvm_ffi_library_bin_string->setAlignment(1);
#endif

// Global init function
Expand Down Expand Up @@ -185,12 +186,12 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
ir_builder.SetInsertPoint(var_init_fn_bb);
llvm::Constant* indices[] = {constant_zero, constant_zero};
llvm::SmallVector<llvm::Value*, 2> args;
args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty,
tvm_dev_mblob_string, indices));
args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_ffi_library_bin_string_ty,
tvm_ffi_library_bin_string, indices));
args.push_back(
llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices));
llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_ffi_library_bin, indices));
auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args);
ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg);
ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_ffi_library_bin_reg);
ir_builder.CreateRetVoid();
}

Expand Down
3 changes: 2 additions & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
}

void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) {
std::string ctx_symbol = system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_module_ctx;
std::string ctx_symbol =
system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_ffi_library_ctx;
// Module context
gv_mod_ctx_ = InitContextPtr(t_void_p_, ctx_symbol);
// Register back the locations.
Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ void LLVMModuleNode::InitMCJIT() {
// run ctors
mcjit_ee_->runStaticConstructorsDestructors(false);

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
if (void** ctx_addr = reinterpret_cast<void**>(
GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) {
*ctx_addr = this;
}
runtime::InitContextFunctions(
Expand Down Expand Up @@ -561,8 +561,8 @@ void LLVMModuleNode::InitORCJIT() {
err = ctorRunner.run();
ICHECK(!err) << llvm::toString(std::move(err));

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
if (void** ctx_addr = reinterpret_cast<void**>(
GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) {
*ctx_addr = this;
}
runtime::InitContextFunctions(
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
namespace tvm {
namespace codegen {

CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }
CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); }

void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str, const std::unordered_set<std::string>& devices) {
Expand All @@ -53,7 +53,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d
}

void CodeGenCHost::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n";
decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx << " = NULL;\n";
}

void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
}

// signature: (void* self, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result)
// signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result)
Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args, v_result};

// Arg definitions are defined before buffer binding to avoid the use before
Expand Down
Loading