From 0e4dce9c8a051080911aa1b39d5caa8cf9640aec Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 5 Jun 2025 11:33:45 -0400 Subject: [PATCH] [REFACTOR][FFI] Update symbol name for library module This PR updates the symbol name for library module to reflect the latest design. The goal is to stablize the name so we can also bring module to ffi layer. --- .../introduction_to_module_serialization.rst | 17 ++++---- ffi/include/tvm/ffi/c_api.h | 4 +- include/tvm/runtime/module.h | 16 ++----- python/tvm/ffi/cython/base.pxi | 2 +- src/runtime/library_module.cc | 13 +++--- src/target/codegen.cc | 4 +- src/target/llvm/codegen_blob.cc | 43 ++++++++++--------- src/target/llvm/codegen_cpu.cc | 3 +- src/target/llvm/llvm_module.cc | 8 ++-- src/target/source/codegen_c_host.cc | 4 +- src/tir/transforms/make_packed_api.cc | 2 +- 11 files changed, 56 insertions(+), 60 deletions(-) diff --git a/docs/arch/introduction_to_module_serialization.rst b/docs/arch/introduction_to_module_serialization.rst index 49cb7d8e4554..50eeb6df2277 100644 --- a/docs/arch/introduction_to_module_serialization.rst +++ b/docs/arch/introduction_to_module_serialization.rst @@ -133,7 +133,7 @@ before. The ``import_tree_logic`` is just to write ``import_tree_row_ptr_`` and ``import_tree_child_indices_`` into stream. After this step, we will pack it into a symbol -``runtime::symbol::tvm_dev_mblob`` that can be recovered in the dynamic +``runtime::symbol::tvm_ffi_library_bin`` that can be recovered in the dynamic library. Now, we complete the serialization part. As you have seen, we could @@ -152,18 +152,19 @@ according to the function logic, we will call ``module.loadfile_so`` in .. code:: c++ // Load the imported modules - const char* dev_mblob = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + const char* library_bin = reinterpret_cast( + lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin)); Module root_mod; - if (dev_mblob != nullptr) { - root_mod = ProcessModuleBlob(dev_mblob, lib); + if (library_bin != nullptr) { + root_mod = ProcessLibraryBin(library_bin, lib); } else { - // Only have one single DSO Module - root_mod = Module(n); + // Only have one single DSO Module + root_mod = Module(n); } As said before, we will pack the blob into the symbol -``runtime::symbol::tvm_dev_mblob``. During deserialization part, we will -inspect it. If we have ``runtime::symbol::tvm_dev_mblob``, we will call ``ProcessModuleBlob``, +``runtime::symbol::tvm_ffi_library_bin``. During deserialization part, we will +inspect it. If we have ``runtime::symbol::tvm_ffi_library_bin``, we will call ``ProcessLibraryBin``, whose logic like this: .. code:: c++ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 996eaa369b3b..7cf7543f482d 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -251,7 +251,7 @@ typedef struct { * * Safe call explicitly catches exception on function boundary. * - * \param self The function handle + * \param handle The function handle * \param num_args Number of input arguments * \param args The input arguments to the call. * \param result Store output result. @@ -278,7 +278,7 @@ typedef struct { * \sa TVMFFIErrorSetRaised * \sa TVMFFIErrorSetRaisedByCStr */ -typedef int (*TVMFFISafeCallType)(void* self, const TVMFFIAny* args, int32_t num_args, +typedef int (*TVMFFISafeCallType)(void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result); /*! diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 705fb276d9e7..efbaa6508af5 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -286,12 +286,10 @@ inline ffi::Function Module::GetFunction(const String& name, bool query_imports) /*! \brief namespace for constant symbols */ namespace symbol { -/*! \brief A ffi::Function that retrieves exported metadata. */ -constexpr const char* tvm_get_c_metadata = "get_c_metadata"; -/*! \brief Global variable to store module context. */ -constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; -/*! \brief Global variable to store device module blob */ -constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob"; +/*! \brief Global variable to store context pointer for a library module. */ +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; +/*! \brief Global variable to store binary data alongside a library module. */ +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; /*! \brief global function to set device */ constexpr const char* tvm_set_device = "__tvm_set_device"; /*! \brief Auxiliary counter to global barrier. */ @@ -300,12 +298,6 @@ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; /*! \brief Placeholder for the module's entry function. */ constexpr const char* tvm_module_main = "__tvm_main__"; -/*! \brief Prefix for parameter symbols emitted into the main program. */ -constexpr const char* tvm_param_prefix = "__tvm_param__"; -/*! \brief A ffi::Function that looks up linked parameters by storage_id. */ -constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; -/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ -constexpr const char* tvm_entrypoint_suffix = "run"; } // namespace symbol // implementations of inline functions. diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 8b9c1f3d947b..e18d52fc8d84 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -131,7 +131,7 @@ cdef extern from "tvm/ffi/c_api.h": void (*update_traceback)(TVMFFIObjectHandle self, const TVMFFIByteArray* traceback) ctypedef int (*TVMFFISafeCallType)( - void* ctx, const TVMFFIAny* args, int32_t num_args, + void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) noexcept int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 18f973daf159..39690bd81b9b 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -113,7 +113,7 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param root_module the output root module * \param dso_ctx_addr the output dso module */ -void ProcessModuleBlob(const char* mblob, ObjectPtr lib, +void ProcessLibraryBin(const char* mblob, ObjectPtr lib, FFIFunctionWrapper packed_func_wrapper, runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); @@ -184,13 +184,13 @@ Module CreateModuleFromLibrary(ObjectPtr lib, FFIFunctionWrapper packed InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib, packed_func_wrapper); // Load the imported modules - const char* dev_mblob = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + const char* library_bin = + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_bin)); Module root_mod; runtime::ModuleNode* dso_ctx_addr = nullptr; - if (dev_mblob != nullptr) { - ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); + if (library_bin != nullptr) { + ProcessLibraryBin(library_bin, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); } else { // Only have one single DSO Module root_mod = Module(n); @@ -198,7 +198,8 @@ Module CreateModuleFromLibrary(ObjectPtr lib, FFIFunctionWrapper packed } // allow lookup of symbol from root (so all symbols are visible). - if (auto* ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { + if (auto* ctx_addr = + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_ffi_library_ctx))) { *ctx_addr = dso_ctx_addr; } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 8ddc071cba0f..9d3f1529c81c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -305,7 +305,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib, << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob; + std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; std::string blob = PackImportsToBytes(mod); // translate to C program @@ -365,7 +365,7 @@ TVM_FFI_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export a few auxiliary function to the runtime namespace. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { - return runtime::symbol::tvm_dev_mblob; + return runtime::symbol::tvm_ffi_library_bin; }); TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 343fdca3eec1..8ea964b6d2d9 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -73,9 +73,9 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob; + std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_ffi_library_bin; - auto* tvm_dev_mblob = new llvm::GlobalVariable( + auto* tvm_ffi_library_bin = new llvm::GlobalVariable( *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, mdev_blob_name, nullptr, llvm::GlobalVariable::NotThreadLocal, 0); @@ -88,17 +88,17 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l const size_t large_data_threshold = 1 << 30; if (data.size() > large_data_threshold && triple.getArch() == llvm::Triple::x86_64 && triple.isOSBinFormatELF()) { - tvm_dev_mblob->setSection(".lrodata"); + tvm_ffi_library_bin->setSection(".lrodata"); } #if TVM_LLVM_VERSION >= 100 - tvm_dev_mblob->setAlignment(llvm::Align(1)); + tvm_ffi_library_bin->setAlignment(llvm::Align(1)); #else - tvm_dev_mblob->setAlignment(1); + tvm_ffi_library_bin->setAlignment(1); #endif if (triple.isOSWindows()) { - tvm_dev_mblob->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); + tvm_ffi_library_bin->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); } if (system_lib) { @@ -109,31 +109,32 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l auto int8_ptr_ty = llvmGetPointerTo(int8_ty, 0); llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty); - auto* tvm_dev_mblob_reg = + auto* tvm_ffi_library_bin_reg = new llvm::GlobalVariable(*module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero, mdev_blob_name + "_reg_"); - auto tvm_dev_mblob_reg_alignment = + auto tvm_ffi_library_bin_reg_alignment = #if TVM_LLVM_VERSION >= 110 module->getDataLayout().getABITypeAlign(int32_ty); #else module->getDataLayout().getABITypeAlignment(int32_ty); #endif #if TVM_LLVM_VERSION >= 100 - tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment)); + tvm_ffi_library_bin_reg->setAlignment(llvm::Align(tvm_ffi_library_bin_reg_alignment)); #else - tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment); + tvm_ffi_library_bin_reg->setAlignment(tvm_ffi_library_bin_reg_alignment); #endif - auto* tvm_dev_mblob_string_ty = llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1); - auto* tvm_dev_mblob_string_value = + auto* tvm_ffi_library_bin_string_ty = + llvm::ArrayType::get(int8_ty, mdev_blob_name.length() + 1); + auto* tvm_ffi_library_bin_string_value = llvm::ConstantDataArray::getString(*ctx, mdev_blob_name, true); - auto* tvm_dev_mblob_string = new llvm::GlobalVariable( - *module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage, - tvm_dev_mblob_string_value, mdev_blob_name + ".str"); + auto* tvm_ffi_library_bin_string = new llvm::GlobalVariable( + *module, tvm_ffi_library_bin_string_ty, true, llvm::GlobalValue::PrivateLinkage, + tvm_ffi_library_bin_string_value, mdev_blob_name + ".str"); #if TVM_LLVM_VERSION >= 100 - tvm_dev_mblob_string->setAlignment(llvm::Align(1)); + tvm_ffi_library_bin_string->setAlignment(llvm::Align(1)); #else - tvm_dev_mblob_string->setAlignment(1); + tvm_ffi_library_bin_string->setAlignment(1); #endif // Global init function @@ -185,12 +186,12 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l ir_builder.SetInsertPoint(var_init_fn_bb); llvm::Constant* indices[] = {constant_zero, constant_zero}; llvm::SmallVector args; - args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty, - tvm_dev_mblob_string, indices)); + args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_ffi_library_bin_string_ty, + tvm_ffi_library_bin_string, indices)); args.push_back( - llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices)); + llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_ffi_library_bin, indices)); auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args); - ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg); + ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_ffi_library_bin_reg); ir_builder.CreateRetVoid(); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index bfbd65e524fb..42a4151f2c5e 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -434,7 +434,8 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { } void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { - std::string ctx_symbol = system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_module_ctx; + std::string ctx_symbol = + system_lib_prefix_.value_or("") + tvm::runtime::symbol::tvm_ffi_library_ctx; // Module context gv_mod_ctx_ = InitContextPtr(t_void_p_, ctx_symbol); // Register back the locations. diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index ed70d8692635..552c2b74c64e 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -437,8 +437,8 @@ void LLVMModuleNode::InitMCJIT() { // run ctors mcjit_ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + if (void** ctx_addr = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } runtime::InitContextFunctions( @@ -561,8 +561,8 @@ void LLVMModuleNode::InitORCJIT() { err = ctorRunner.run(); ICHECK(!err) << llvm::toString(std::move(err)); - if (void** ctx_addr = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + if (void** ctx_addr = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_ffi_library_ctx, *llvm_target))) { *ctx_addr = this; } runtime::InitContextFunctions( diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index ad73fc9079e9..e2535dbf6845 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -34,7 +34,7 @@ namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); } +CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); } void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, const std::unordered_set& devices) { @@ -53,7 +53,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d } void CodeGenCHost::InitGlobalContext() { - decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; + decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx << " = NULL;\n"; } void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 340e018a8db8..250f51d09c57 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -336,7 +336,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } - // signature: (void* self, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result) + // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result) Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; // Arg definitions are defined before buffer binding to avoid the use before