diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index efbaa6508af5..80c03ea75132 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -290,14 +290,14 @@ namespace symbol { constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; /*! \brief Global variable to store binary data alongside a library module. */ constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; +/*! \brief Placeholder for the module's entry function. */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; /*! \brief global function to set device */ constexpr const char* tvm_set_device = "__tvm_set_device"; /*! \brief Auxiliary counter to global barrier. */ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; /*! \brief Prepare the global barrier before kernels that uses global barrier. */ constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; -/*! \brief Placeholder for the module's entry function. */ -constexpr const char* tvm_module_main = "__tvm_main__"; } // namespace symbol // implementations of inline functions. 
diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 5e78e26ae739..9fa65054f91f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -46,7 +46,7 @@ private static Function getApi(String name) { } private Function entry = null; - private final String entryName = "__tvm_main__"; + private final String entryName = "__tvm_ffi_main__"; /** diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 3dd4de5da0a8..e645d3a2b6ce 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -103,7 +103,7 @@ class Module(tvm.ffi.Object): def __new__(cls): instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter - instance.entry_name = "__tvm_main__" + instance.entry_name = "__tvm_ffi_main__" instance._entry = None return instance @@ -118,7 +118,7 @@ def entry_func(self): """ if self._entry: return self._entry - self._entry = self.get_function("__tvm_main__") + self._entry = self.get_function("__tvm_ffi_main__") return self._entry def implements_function(self, name, query_imports=False): diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 2435cccf0a98..6b71df928d23 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier { ffi::Function CUDAModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; if (name == symbol::tvm_prepare_global_barrier) { return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); } diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index fffac4adea85..24fc7518d6ad 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -50,15 +50,7 @@ class LibraryModuleNode 
final : public ModuleNode { ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final { TVMFFISafeCallType faddr; - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = - reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main)); - ICHECK(entry_name != nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; - faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name)); - } else { - faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str())); - } + faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str())); if (faddr == nullptr) return ffi::Function(); return packed_func_wrapper_(faddr, sptr_to_self); } diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index be36e6197f36..33bb1705c8e4 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -264,7 +264,6 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) ffi::Function ret; AUTORELEASEPOOL { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) { ret = ffi::Function(); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 19b426d4b40a..1c61eeb59635 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -138,7 +138,6 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 791e4b156979..a871a41f0f86 100644 --- 
a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -195,7 +195,6 @@ class ROCMWrappedFunc { ffi::Function ROCMModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index f4922a1bf01d..db81c959dccd 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() { ffi::Function VulkanModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 4dd24026c0c8..6271d4edbe30 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -229,28 +229,42 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { - llvm::Function* f = module_->getFunction(entry_func_name); - ICHECK(f) << "Function " << entry_func_name << "does not in module"; - llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); - llvm::GlobalVariable* global = - new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr, - runtime::symbol::tvm_module_main); -#if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(1)); -#else - global->setAlignment(1); -#endif - // comdat is needed for windows select any linking to work - // set comdat to 
Any(weak linking) + // create a wrapper function with tvm_ffi_main name and redirects to the entry function + llvm::Function* target_func = module_->getFunction(entry_func_name); + ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; + + // Create wrapper function + llvm::Function* wrapper_func = + llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage, + runtime::symbol::tvm_ffi_main, module_.get()); + + // Set attributes (Windows comdat, DLL export, etc.) if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { - llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); + llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main); comdat->setSelectionKind(llvm::Comdat::Any); - global->setComdat(comdat); + wrapper_func->setComdat(comdat); } - global->setInitializer( - llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name)); - global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); + wrapper_func->setCallingConv(llvm::CallingConv::C); + wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + + // Create simple tail call + llvm::BasicBlock* entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", wrapper_func); + builder_->SetInsertPoint(entry); + + // Forward all arguments to target function + std::vector<llvm::Value*> call_args; + for (llvm::Value& arg : wrapper_func->args()) { + call_args.push_back(&arg); + } + + llvm::Value* result = builder_->CreateCall(target_func, call_args); + if (target_func->getReturnType()->isVoidTy()) { + builder_->CreateRetVoid(); + } else { + builder_->CreateRet(result); + } } std::unique_ptr<llvm::Module> CodeGenCPU::Finish() { diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index a9e09652eec0..2daf941edf01 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -190,15 
+190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, TVMFFISafeCallType faddr; With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast<const char*>( - GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); - ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main - << " is not presented"; - faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, *llvm_target)); - } else { - faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target)); - } + faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target)); if (faddr == nullptr) return ffi::Function(); return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 6cd12a931962..020054b3e1fc 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -77,11 +77,11 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; - function_names_.push_back(runtime::symbol::tvm_module_main); + function_names_.push_back(runtime::symbol::tvm_ffi_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); PrintType(func->ret_type, stream); - stream << " " << tvm::runtime::symbol::tvm_module_main + stream << " " << tvm::runtime::symbol::tvm_ffi_main << "(void* self, void* args,int num_args, void* result) {\n"; stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n"; stream << "}\n"; diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index fe7a615531ae..9461da2277eb 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ 
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -289,9 +289,13 @@ def evaluate( if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 + ) else: - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 + ) time = timer(a_hexagon, b_hexagon, c_hexagon) if expected_output is not None: diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 8f77fa1c4016..682235256847 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): repeat = 1 timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat + "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat ) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a584997dd507..17e31af0a793 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -331,7 +331,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global") repeat = 1 timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat + "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat ) time = 
timer(a_hexagon, b_hexagon, c_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) @@ -365,7 +365,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): repeat = 1 timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat + "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat ) time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index bd9c78d5daa1..dd765178dc32 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -105,7 +105,7 @@ def evaluate(hexagon_session, operations, expected, sch): repeat = 1 timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat + "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat ) runtime = timer(a_hexagon, b_hexagon, c_hexagon) diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 931f99b2ec92..265f2bf5fd2d 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -108,9 +108,13 @@ def evaluate(hexagon_session, sch, size): if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 + ) else: - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 + ) runtime = 
timer(a_hexagon, a_vtcm_hexagon)