diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h index bc7dff159cda..1af2c2b6b2c0 100644 --- a/ffi/include/tvm/ffi/extra/module.h +++ b/ffi/include/tvm/ffi/extra/module.h @@ -223,14 +223,19 @@ class Module : public ObjectRef { * \brief Symbols for library module. */ namespace symbol { +/*!\ brief symbol prefix for tvm ffi related function symbols */ +constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; +// Special symbols have one extra _ prefix to avoid conflict with user symbols +/*! + * \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" + */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; /*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; /*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; /*! \brief Optional metadata prefix of a symbol. */ -constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi_metadata_"; -/*! \brief Default entry function of a library module. */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; +constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; } // namespace symbol } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 5a30f25a7b5b..f84978800e36 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -800,19 +800,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) { * * \endcode */ -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ - TVMFFIAny* result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any*>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ + TVMFFIAny* result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ + static std::string name = #ExportName; \ + ::tvm::ffi::details::unpack_call( \ + std::make_index_sequence{}, &name, Function, \ + reinterpret_cast(args), num_args, \ + reinterpret_cast<::tvm::ffi::Any*>(result)); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ } } // namespace ffi } // namespace tvm diff --git a/ffi/python/tvm_ffi/module.py b/ffi/python/tvm_ffi/module.py index 56aa15348e8c..c3c1d089c612 100644 --- a/ffi/python/tvm_ffi/module.py +++ b/ffi/python/tvm_ffi/module.py @@ -40,7 +40,7 @@ class Module(core.Object): def __new__(cls): instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter - instance.entry_name = "__tvm_ffi_main__" + instance.entry_name = "main" instance._entry = None return instance @@ -55,7 +55,7 @@ def entry_func(self): """ if self._entry: return self._entry - self._entry = self.get_function("__tvm_ffi_main__") + self._entry = self.get_function("main") return self._entry @property diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc index 71c6da6f7cc4..2864cdb5904a 100644 --- a/ffi/src/ffi/extra/library_module.cc +++ b/ffi/src/ffi/extra/library_module.cc @@ -42,7 +42,7 @@ class LibraryModuleObj final : public ModuleObj { Optional GetFunction(const String& name) final { TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); + faddr = reinterpret_cast(lib_->GetSymbolWithSymbolPrefix(name)); // ensure the function keeps the Library Module alive Module self_strong_ref = GetRef(this); if (faddr != nullptr) { @@ -140,7 +140,7 @@ class ContextSymbolRegistry { public: void InitContextSymbols(ObjectPtr lib) { for (const auto& [name, symbol] : context_symbols_) { - if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name.c_str()))) { + if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name))) { *symbol_addr = symbol; } } diff --git a/ffi/src/ffi/extra/library_module_dynamic_lib.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc index 25463a7e5f92..e85b05180baf 100644 --- a/ffi/src/ffi/extra/library_module_dynamic_lib.cc +++ b/ffi/src/ffi/extra/library_module_dynamic_lib.cc @@ -49,7 +49,7 @@ class DSOLibrary final : public Library { if (lib_handle_) Unload(); } - void* GetSymbol(const char* name) final { return GetSymbol_(name); } + void* GetSymbol(const String& name) final { return GetSymbol_(name.c_str()); } private: // private system dependent implementation diff --git a/ffi/src/ffi/extra/library_module_system_lib.cc b/ffi/src/ffi/extra/library_module_system_lib.cc index cdc932cba292..e93c6602c267 100644 --- a/ffi/src/ffi/extra/library_module_system_lib.cc +++ b/ffi/src/ffi/extra/library_module_system_lib.cc @@ -45,7 +45,7 @@ class SystemLibSymbolRegistry { symbol_table_.Set(name, ptr); } - void* GetSymbol(const char* name) { + void* GetSymbol(const String& name) { auto it = symbol_table_.find(name); if (it != symbol_table_.end()) { return (*it).second; @@ -68,13 +68,14 @@ class SystemLibrary final : public Library { public: explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} - void* GetSymbol(const char* name) { - if (symbol_prefix_.length() != 0) { - String name_with_prefix = symbol_prefix_ + name; - void* symbol = reg_->GetSymbol(name_with_prefix.c_str()); - if (symbol != nullptr) return symbol; - } - return reg_->GetSymbol(name); + void* GetSymbol(const String& name) final { + String name_with_prefix = symbol_prefix_ + name; + return reg_->GetSymbol(name_with_prefix); + } + + void* GetSymbolWithSymbolPrefix(const String& name) final { + String name_with_prefix = symbol::tvm_ffi_symbol_prefix + symbol_prefix_ + name; + return reg_->GetSymbol(name_with_prefix); } private: diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h index 472d531f4b51..86cb6b66c1f6 100644 --- a/ffi/src/ffi/extra/module_internal.h +++ b/ffi/src/ffi/extra/module_internal.h @@ -48,7 +48,17 @@ class Library : public Object { * \param name The name of the symbol. * \return The symbol. */ - virtual void* GetSymbol(const char* name) = 0; + virtual void* GetSymbol(const String& name) = 0; + /*! + * \brief Get the symbol address for a given name with the tvm ffi symbol prefix. + * \param name The name of the symbol. + * \return The symbol. + * \note This function will be overloaded by systemlib implementation. + */ + virtual void* GetSymbolWithSymbolPrefix(const String& name) { + String name_with_prefix = symbol::tvm_ffi_symbol_prefix + name; + return GetSymbol(name_with_prefix); + } // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting and only need to use the refcounting }; diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 46a74346760e..174457131f05 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -46,7 +46,7 @@ private static Function getApi(String name) { } private Function entry = null; - private final String entryName = "__tvm_ffi_main__"; + private final String entryName = "main"; /** diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 5ce8b1ec6584..34e9e8381898 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -229,6 +229,11 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { + if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) { + // main already exists, no need to create a wrapper function + // main takes precedence over other entry functions + return; + } // create a wrapper function with tvm_ffi_main name and redirects to the entry function llvm::Function* target_func = module_->getFunction(entry_func_name); ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; @@ -857,8 +862,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& call_args.push_back(GetPackedFuncHandle(func_name)); call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } else { + // directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix callee_ftype = ftype_tvm_ffi_c_func_; - callee_value = module_->getFunction(func_name); + callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name); if (callee_value == nullptr) { callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 8ea438626532..6c88d6943423 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -189,7 +189,8 @@ Optional LLVMModuleNode::GetFunction(const String& name) { TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name; + faddr = reinterpret_cast(GetFunctionAddr(name_with_prefix, *llvm_target)); if (faddr == nullptr) return std::nullopt; ffi::Module self_strong_ref = GetRef(this); return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { @@ -386,7 +387,8 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { } bool LLVMModuleNode::ImplementsFunction(const String& name) { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); + return std::find(function_names_.begin(), function_names_.end(), + ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end(); } void LLVMModuleNode::InitMCJIT() { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index acc05cf96c08..65c57cf882b4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -149,7 +149,9 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { return gvar->name_hint; } }(); - + if (function_name == ffi::symbol::tvm_ffi_main) { + has_tvm_ffi_main_func_ = true; + } internal_functions_.insert({gvar, function_name}); InitFuncState(func); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 8c5e1ffd897b..02cb4cd9a779 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -319,6 +319,8 @@ class CodeGenC : public ExprFunctor, Integer constants_byte_alignment_ = 16; /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; + /*! \brief whether the module has a main function declared */ + bool has_tvm_ffi_main_func_{false}; private: /*! \brief set of volatile buf access */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index e18ba0128d6b..a4cbc46f0cca 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -35,7 +35,9 @@ namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); } +CodeGenCHost::CodeGenCHost() { + module_name_ = name_supply_->FreshName(ffi::symbol::tvm_ffi_library_ctx); +} void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, const std::unordered_set& devices) { @@ -72,7 +74,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; @@ -235,7 +237,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { } else { // directly use the original symbol ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); - packed_func_name = func_name->value; + packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; } std::string args_stack = PrintExpr(op->args[1]); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 4a2f530e2f98..1c7e65b3b2cb 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,6 +44,7 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set& devices); void InitGlobalContext(); + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl); /*! @@ -83,6 +84,8 @@ class CodeGenCHost : public CodeGenC { bool emit_asserts_; /*! \brief whether to emit forwared function declarations in the resulting C code */ bool emit_fwd_func_decl_; + /*! \brief whether to generate the entry function if encountered */ + bool has_main_func_ = false; std::string GetPackedName(const CallNode* op); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index e6c6e9aa0275..f557cab91ad8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,6 +20,7 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ +#include #include #include #include @@ -196,7 +197,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { return std::nullopt; } - return global_symbol; + return global_symbol.value(); } PrimFunc MakePackedAPI(PrimFunc func) { @@ -223,6 +224,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } auto* func_ptr = func.CopyOnWrite(); + // set the global symbol to the packed function name const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -362,10 +364,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } - - func = WithAttrs(std::move(func), - {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}}); + // reset global symbol to attach prefix + func = WithAttrs( + std::move(func), + {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index 3c80cfbeb0b4..8f3798861f46 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -184,17 +184,9 @@ def subroutine(A_data: T.handle("float32")): built = tvm.tir.build(mod, target="c") - func_names = list(built["get_func_names"]()) - assert ( - "main" in func_names - ), "Externally exposed functions should be listed in available functions." - assert ( - "subroutine" not in func_names - ), "Internal function should not be listed in available functions." - source = built.inspect_source() assert ( - source.count("main(void*") == 2 + source.count("__tvm_ffi_main(void*") == 2 ), "Expected two occurrences, for forward-declaration and definition" assert ( source.count("subroutine(float*") == 2 diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 953adf78b342..b303cf289eca 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -953,7 +953,10 @@ def test_llvm_target_attributes(): assert re.match('.*"target-cpu"="skylake".*', attribute_definitions[k]) assert re.match('.*"target-features"=".*[+]avx512f.*".*', attribute_definitions[k]) - expected_functions = ["test_func", "test_func_compute_", "__tvm_parallel_lambda"] + expected_functions = [ + "__tvm_ffi_test_func", + "__tvm_parallel_lambda", + ] for n in expected_functions: assert n in functions_with_target diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 965795d29e02..ab1cce52eac8 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. """ +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import pytest @@ -287,13 +287,9 @@ def evaluate( if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) time = timer(a_hexagon, b_hexagon, c_hexagon) if expected_output is not None: diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 6e1b7db4d5c5..cab3f7d64f9b 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -156,9 +156,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a0b94d89cfa6..89385b2aeb8f 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. """ +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import tvm @@ -326,9 +326,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global") number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() @@ -360,9 +358,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index dd765178dc32..d9b9a2480312 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test parallelism for multiple different scalar workloads. """ +"""Test parallelism for multiple different scalar workloads.""" import numpy as np @@ -104,9 +104,7 @@ def evaluate(hexagon_session, operations, expected, sch): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b)) diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 265f2bf5fd2d..015a9f0656ed 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -108,13 +108,9 @@ def evaluate(hexagon_session, sch, size): if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) runtime = timer(a_hexagon, a_vtcm_hexagon) diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index dd7bd3bf54a2..4fecafef1d15 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -261,6 +261,7 @@ def func_without_arg( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_func_without_arg", } ) assert num_args == 0, "func_without_arg: num_args should be 0" @@ -315,6 +316,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1" @@ -372,6 +374,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1"