From eea84c86ff2fbea826a3b88a7a51e8baf213ffac Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 2 Feb 2021 19:14:25 +0000 Subject: [PATCH] [AOT] Name mangling in AOT Mini-RFC is here: https://discuss.tvm.apache.org/t/mini-rfc-name-mangling-in-aot With this change we'll mangle the name of global symbols so that we can bundle together multiple models in the same application. The relay.build interface has been left unchanged, which means I am resuing mod_name as a prefix for all functions. If mod_name is None then a "_tvm" prefix is used. I had to add two different compilation functions: - _CompileEngineLowerWithModuleName to mangle all the operators with the mod_name - PartitionGraphWithModName to mangle all the operators produced by BYOC I could have changed signature of both, but that would have meant a very invasive refactoring. I refactored the aot test utils and added some tests for multiple models. Change-Id: I30e93fa075f660054577ea36cf9268ec0c6eebcb --- apps/microtvm/zephyr/aot_demo/src/main.c | 4 +- include/tvm/runtime/module.h | 2 +- python/tvm/micro/model_library_format.py | 16 +- python/tvm/relay/backend/compile_engine.py | 6 +- .../relay/backend/graph_executor_codegen.py | 4 +- python/tvm/relay/backend/utils.py | 37 +++ python/tvm/relay/build_module.py | 15 +- python/tvm/relay/transform/transform.py | 6 +- src/relay/backend/aot_executor_codegen.cc | 30 ++- src/relay/backend/build_module.cc | 15 +- src/relay/backend/compile_engine.cc | 19 +- src/relay/backend/compile_engine.h | 3 +- src/relay/backend/graph_executor_codegen.cc | 12 +- src/relay/backend/vm/compiler.cc | 3 +- src/relay/transforms/partition_graph.cc | 85 ++++++- src/runtime/meta_data.h | 13 +- src/target/source/codegen_c_host.cc | 6 +- src/target/source/codegen_c_host.h | 2 + src/target/source/source_module.cc | 9 +- tests/cpp/relay_build_module_test.cc | 2 +- tests/cpp/utvm_runtime_standalone_test.cc | 2 +- .../contrib/test_bnns/test_conv2d_patterns.py | 6 +- .../contrib/test_ethosn/test_networks.py | 16 +- tests/python/contrib/test_tensorrt.py | 6 +- .../test_vitis_ai/test_vitis_ai_codegen.py | 5 +- tests/python/relay/aot/aot_test.mk | 3 +- tests/python/relay/aot/aot_test_utils.py | 240 +++++++++++++----- tests/python/relay/aot/test_crt_aot.py | 76 +++++- tests/python/relay/test_json_runtime.py | 32 +-- tests/python/relay/test_op_fast_math.py | 2 +- .../python/relay/test_pass_partition_graph.py | 62 ++--- .../test_micro_model_library_format.py | 12 +- 32 files changed, 557 insertions(+), 194 deletions(-) create mode 100644 python/tvm/relay/backend/utils.py diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/aot_demo/src/main.c index b92366a7098b9..2aacda70f5dc9 100644 --- a/apps/microtvm/zephyr/aot_demo/src/main.c +++ b/apps/microtvm/zephyr/aot_demo/src/main.c @@ -41,7 +41,7 @@ #define WORKSPACE_SIZE (270 * 1024) static uint8_t g_aot_memory[WORKSPACE_SIZE]; -extern tvm_model_t network; +extern tvm_model_t tvmgen_default_network; tvm_workspace_t app_workspace; // Wakeup sequence used to wake up QEMU on the host. @@ -205,7 +205,7 @@ void main(void) { double elapsed_time = 0; TVMPlatformTimerStart(); - int ret_val = tvm_runtime_run(&network, inputs, outputs); + int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs); TVMPlatformTimerStop(&elapsed_time); if (ret_val != 0) { diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 689fe6fa53fce..9dd7423c66797 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -231,7 +231,7 @@ constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; /*! \brief The main AOT executor function */ -constexpr const char* tvm_run_func_prefix = "tvm__run_func"; +constexpr const char* tvm_run_func_suffix = "run_model"; } // namespace symbol // implementations of inline functions. diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 1cc3adf9ae07e..7062b20e0d54b 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -35,7 +35,7 @@ class UnsupportedInModelLibraryFormatError(Exception): """Raised when export_model_library_format does not support the given Module tree.""" -def _populate_codegen_dir(mod, codegen_dir: str): +def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """Populate the codegen sub-directory as part of a Model Library Format export. Parameters @@ -44,6 +44,9 @@ def _populate_codegen_dir(mod, codegen_dir: str): Module which should be written to codegen_dir. codegen_dir : str Path to the codegen directory on disk. + module_name: Optional[str] + Name used to prefix the generated source files + """ dso_modules = mod._collect_dso_modules() dso_module_handles = [m.handle.value for m in dso_modules] @@ -55,17 +58,19 @@ def _populate_codegen_dir(mod, codegen_dir: str): mod_indices = {"lib": 0, "src": 0} host_codegen_dir = os.path.join(codegen_dir, "host") + lib_name = f"{module_name}_lib" if module_name else "lib" + for dso_mod in dso_modules: if dso_mod.type_key == "c": index = mod_indices["src"] mod_indices["src"] += 1 parent_dir = os.path.join(host_codegen_dir, "src") - file_name = os.path.join(parent_dir, f"lib{index}.c") + file_name = os.path.join(parent_dir, f"{lib_name}{index}.c") elif dso_mod.type_key == "llvm": index = mod_indices["lib"] mod_indices["lib"] += 1 parent_dir = os.path.join(host_codegen_dir, "lib") - file_name = os.path.join(parent_dir, f"lib{index}.o") + file_name = os.path.join(parent_dir, f"{lib_name}{index}.o") else: assert ( False @@ -98,7 +103,6 @@ def _build_sid_map(graph_json): A list with one entry per storage id describing that memory. """ graph = json.loads(graph_json) - seen_storage_ids = set() memory_map = [] for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]): @@ -227,7 +231,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil runtime = ["aot"] if is_aot else ["graph"] metadata = { - "version": 2, + "version": 3, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), @@ -240,7 +244,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil codegen_dir_path = tempdir.relpath("codegen") os.mkdir(codegen_dir_path) - _populate_codegen_dir(mod.lib, codegen_dir_path) + _populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name) parameters_dir_path = tempdir.relpath("parameters") os.mkdir(parameters_dir_path) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 13ecee7debe2f..2db8c5a669f08 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -26,6 +26,7 @@ from tvm.runtime import Object from tvm.support import libinfo from tvm.target import Target +from ..backend.utils import mangle_module_name from .. import function as _function from .. import ty as _ty from . import _backend @@ -328,7 +329,7 @@ class CompileEngine(Object): def __init__(self): raise RuntimeError("Cannot construct a CompileEngine") - def lower(self, source_func, target=None): + def lower(self, source_func, target=None, mod_name="default"): """Lower a source_func to a CachedFunc. Parameters @@ -346,8 +347,9 @@ def lower(self, source_func, target=None): """ # pylint: disable=broad-except, import-outside-toplevel try: + mod_name = mangle_module_name(mod_name) key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key) + return _backend._CompileEngineLower(self, key, mod_name) except Exception: import traceback diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index f24bf2c2b55b3..11274b97197f2 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -37,6 +37,7 @@ from tvm.relay import _build_module from tvm.target import Target from tvm.tir import expr as _expr +from .utils import mangle_module_name class GraphExecutorCodegen(object): @@ -80,7 +81,8 @@ def codegen(self, func): params : Dict[str, tvm.nd.NDArray] Additional constant parameters. """ - self._codegen(func) + default_mod_name = mangle_module_name("default") + self._codegen(func, default_mod_name) graph_json = self._get_graph_json() lowered_func = self._get_irmodule() param_names = self._list_params_name() diff --git a/python/tvm/relay/backend/utils.py b/python/tvm/relay/backend/utils.py new file mode 100644 index 0000000000000..b8430a9e6b6eb --- /dev/null +++ b/python/tvm/relay/backend/utils.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility backend functions.""" + + +def _is_valid_modname(mod_name): + """Determine if mod_name is a valid string to use inside function names""" + if mod_name: + try: + mod_name.encode("ascii") + return True + except UnicodeEncodeError: + return False + + return True + + +def mangle_module_name(mod_name): + if not _is_valid_modname(mod_name): + raise ValueError(mod_name + " contains invalid characters") + if mod_name: + return "tvmgen_" + mod_name + return "tvmgen" diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e134eeeefd09b..ed722643ff707 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -34,6 +34,7 @@ from . import expr as _expr from . import function as _function from .transform import InferType +from .backend.utils import mangle_module_name from .backend import executor_factory as _executor_factory from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -85,7 +86,9 @@ def __init__(self): self._get_params_func = self.mod["get_params"] self._get_function_metadata = self.mod["get_function_metadata"] - def build(self, mod, target=None, target_host=None, params=None, executor="graph"): + def build( + self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None + ): """ Parameters ---------- @@ -115,6 +118,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph - If "graph" is specified, then the graph_executor will be used - If "aot" is specified, then the aot_executor will be used + mod_name: Optional[str] + The module name we will build + Returns ------- graph_json : str @@ -145,7 +151,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler - self._build(mod, target, target_host, executor) + mod_name = mangle_module_name(mod_name) + + self._build(mod, target, target_host, executor, mod_name) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent # Get artifacts @@ -295,6 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" """ # pylint: enable=line-too-long # fmt: on + if not isinstance(ir_mod, (IRModule, _function.Function)): raise ValueError("Type of input parameter mod must be tvm.IRModule") @@ -330,7 +339,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" with tophub_context: bld_mod = BuildModule() executor_config, runtime_mod, params = bld_mod.build( - mod=ir_mod, target=target, params=params, executor=executor + mod=ir_mod, target=target, params=params, executor=executor, mod_name=mod_name ) func_metadata = bld_mod.get_function_metadata() diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e8bb94c5014..312480867cf45 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -29,6 +29,7 @@ from tvm import relay from . import _ffi_api +from ..backend.utils import mangle_module_name def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None): @@ -714,7 +715,7 @@ def LambdaLift(): return _ffi_api.LambdaLift() -def PartitionGraph(): +def PartitionGraph(mod_name="default"): """Partition a Relay program into regions that can be executed on different backends. @@ -723,7 +724,8 @@ def PartitionGraph(): ret: tvm.transform.Pass The registered pass that partitions the Relay program. """ - return _ffi_api.PartitionGraph() + mod_name = mangle_module_name(mod_name) + return _ffi_api.PartitionGraph(mod_name) def AnnotateTarget(targets, include_non_call_ops=True): diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index a005247d424a1..165af59af7d53 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -359,11 +359,12 @@ class AOTExecutorCodegen : public ExprVisitor { auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; + // Handle external function if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key); + CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); ICHECK(ext_func.defined()) << "External function is not defined."; UpdateConstants(func, ¶ms_); @@ -394,7 +395,7 @@ class AOTExecutorCodegen : public ExprVisitor { target = targets_[call_dev_type]; } CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key); + CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -517,7 +518,10 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the PrimFunc attributes Map dict_attrs; - dict_attrs.Set("global_symbol", runtime::String(runtime::symbol::tvm_run_func_prefix)); + String run_func_name = + runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix); + dict_attrs.Set("global_symbol", run_func_name); + dict_attrs.Set("runner_function", Bool(true)); // Make the PrimFunc return tir::PrimFunc(main_signature_, body, VoidType(), Map(), @@ -561,6 +565,8 @@ class AOTExecutorCodegen : public ExprVisitor { std::vector stmts_; /*! \brief the list of return sids (note that the function might return more then one output */ IntegerArray return_sid_; + /*! \brief the module name we use to mangle the function names */ + String mod_name_; public: AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) @@ -570,10 +576,11 @@ class AOTExecutorCodegen : public ExprVisitor { target_host_ = target_host; } - LoweredOutput Codegen(relay::Function func) { + LoweredOutput Codegen(relay::Function func, String mod_name) { // Get the module, storage map and token sizes auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); storage_device_map_ = (*pf)(func); + mod_name_ = mod_name; int input_index = 0; for (auto input : func->params) { @@ -621,15 +628,15 @@ class AOTExecutorCodegen : public ExprVisitor { auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { ret.lowered_funcs[target_host_str]->Add( - GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func); + GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); } else { Map symbol_map; - symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func); + symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map)); } ret.function_metadata = std::move(function_metadata_); - ret.metadata = - runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot); + ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(), + runtime::kTvmExecutorAot, mod_name); return ret; } }; @@ -649,7 +656,8 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; - this->output_ = codegen(func); + String mod_name = args[1]; + this->output_ = codegen(func, mod_name); }); } else if (name == "list_params_name") { return PackedFunc( @@ -700,7 +708,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { targets, target_host); } - LoweredOutput codegen(Function func) { return this->codegen_->Codegen(func); } + LoweredOutput codegen(Function func, String mod_name) { + return this->codegen_->Codegen(func, mod_name); + } Array list_params_name() { Array ret; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 00b6fed8c64ac..23670109e5270 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -58,7 +58,7 @@ struct BuildOutput { struct ExecutorCodegen { void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } - void Codegen(const Function& func) { CallFunc("codegen", func); } + void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } virtual void UpdateOutput(BuildOutput* ret) = 0; @@ -177,8 +177,8 @@ class RelayBuildModule : public runtime::ModuleNode { [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 4); - this->Build(args[0], args[1], args[2], args[3]); + ICHECK_EQ(args.num_args, 5); + this->Build(args[0], args[1], args[2], args[3], args[4]); }); } else if (name == "list_params") { return PackedFunc( @@ -279,13 +279,13 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target_host Host target device */ void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, - const String executor) { + const String executor, const String mod_name) { // Create protected variable targets_ from ground up targets_ = targets; target_host_ = target_host; executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); - BuildRelay(mod, params_); + BuildRelay(mod, params_, mod_name); // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. CompileEngine::Global()->Clear(); } @@ -508,7 +508,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param params The parameters. */ void BuildRelay(IRModule relay_module, - const std::unordered_map& params) { + const std::unordered_map& params, + const String mod_name) { Target target_host = GetTargetHost(); // If no target_host has been set, we choose a default one, which is // llvm if "codegen.LLVMModuleCreate" is accessible. @@ -527,7 +528,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_); executor_codegen_->Init(nullptr, targets_); - executor_codegen_->Codegen(func); + executor_codegen_->Codegen(func, mod_name); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 5e3b66b3ae15a..574322b4d4b13 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -45,6 +45,7 @@ #include #include +#include "../../runtime/meta_data.h" #include "../transforms/pass_utils.h" #include "utils.h" @@ -612,11 +613,14 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } + CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { + return LowerInternal(key, mangle_fn)->cached_func; + } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { - CCacheValue value = LowerInternal(key); + auto mangle_fn = [](String name) { return name; }; + CCacheValue value = LowerInternal(key, mangle_fn); if (value->packed_func != nullptr) return value->packed_func; // build the function. tvm::runtime::Module m; @@ -711,7 +715,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -755,8 +759,8 @@ class CompileEngineImpl : public CompileEngineNode { return value; } } + cache_node->func_name = GetUniqueName(mangle_fn(cache_node->func_name)); - cache_node->func_name = GetUniqueName(cache_node->func_name); // NOTE: array will copy on write. Array all_args = cache_node->inputs; for (te::Tensor arg : cache_node->outputs) { @@ -876,7 +880,12 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); + .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { + return runtime::get_name_mangled(mod_name, name); + }; + return self->Lower(key, mangle_fn); + }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index d7628e7a5bdf1..f766fcf97ea71 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -199,9 +199,10 @@ class CompileEngineNode : public Object { /*! * \brief Get lowered result. * \param key The key to the cached function. + * \param mod_name The module name to mangle the functions * \return The result. */ - virtual CachedFunc Lower(const CCacheKey& key) = 0; + virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index d92d4d2077f76..bca8e82440937 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -270,9 +270,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorparams) { @@ -547,7 +548,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key); + CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); ICHECK(ext_func.defined()) << "External function is not defined."; UpdateConstants(func, ¶ms_); return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name, attrs); @@ -573,7 +574,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorstr())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -724,6 +725,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator param_storage_ids_; /*! \brief plan memory of device result */ Map> storage_device_map_; + /*! \brief the module name we use to mangle the function names */ + String mod_name_; /*! \brief lowered funcs */ std::unordered_map lowered_funcs_; /*! \brief lowered funcs */ @@ -756,7 +759,8 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; - this->output_ = this->codegen_->Codegen(func); + String mod_name = args[1]; + this->output_ = this->codegen_->Codegen(func, mod_name); }); } else if (name == "get_graph_json") { return PackedFunc( diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b43972d686cc4..ebfe506029f3a 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -545,7 +545,8 @@ class VMFunctionCompiler : ExprFunctor { } CCacheKey key(func, target); - auto cfunc = engine_->Lower(key); + auto mangle_fn = [](String name) { return name; }; + auto cfunc = engine_->Lower(key, mangle_fn); auto op_index = -1; if (func->GetAttr(attr::kCompiler).defined()) { diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 94891c3c98ea7..8fe79a6adb552 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -480,11 +480,77 @@ IRModule FlattenTupleOutputs(IRModule module) { return module; } +class NameMangleExtFuncs : public MixedModeMutator { + public: + explicit NameMangleExtFuncs(const IRModule& module, std::function mangle_fn) + : module_(module), mangle_fn_(mangle_fn) {} + + IRModule Run() { + auto glob_funcs = module_->functions; + + // Collect function names to be mangled and create + // global mangled variables + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + if (func->GetAttr(attr::kCompiler).defined()) { + auto fn_name_mangled = mangle_fn_(pair.first->name_hint); + GlobalVar gvar = GlobalVar(fn_name_mangled); + mangled_gvars_[pair.first->name_hint] = gvar; + } + } + } + + // Walk the three and mangle the functions. Then replace compiler functions + // with mangled functions in the module + IRModule new_module; + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + + if (func->GetAttr(attr::kCompiler).defined()) { + auto new_dict = func->attrs->dict; + new_dict.Set(tvm::attr::kGlobalSymbol, String(mangle_fn_(pair.first->name_hint))); + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + DictAttrs(new_dict)); + new_module->Add(mangled_gvars_[pair.first->name_hint], func); + } else { + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + func->attrs); + new_module->Add(pair.first, func); + } + } + } + + return new_module; + } + + private: + Expr Rewrite_(const CallNode* call, const Expr& post) final { + Expr new_expr = post; + const CallNode* new_call = new_expr.as(); + auto op_node = new_call->op.as(); + if (op_node == nullptr || mangled_gvars_.find(op_node->name_hint) == mangled_gvars_.end()) { + return new_expr; + } else { + return Call(mangled_gvars_[op_node->name_hint], new_call->args, new_call->attrs, + new_call->type_args, new_call->span); + } + } + + /*!\brief The IRModule used for partitioning. */ + IRModule module_; + /*!\brief The function used to mangle operators name */ + std::function mangle_fn_; + /*!\brief Tabled used to store (unmangled_var_name, mangled_gvar) pairs*/ + std::unordered_map mangled_gvars_; +}; + } // namespace partitioning namespace transform { -Pass PartitionGraph() { +Pass PartitionGraph(String mod_name) { runtime::TypedPackedFunc flatten_tuples = [=](IRModule m, PassContext pc) { // There could be compiler_end annotations on tuples @@ -506,13 +572,26 @@ Pass PartitionGraph() { runtime::TypedPackedFunc part_func = [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; + auto name_mangling_fn = [mod_name](String name) { + return runtime::get_name_mangled(mod_name, name); + }; + + runtime::TypedPackedFunc name_mangling_func = + [=](IRModule m, PassContext pc) { + return partitioning::NameMangleExtFuncs(m, name_mangling_fn).Run(); + }; + auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {}); auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {}); auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {}); - return Sequential({flatten_tuples_pass, remove_default_pass, partition_pass, InferType()}); + auto name_mangling_pass = CreateModulePass(name_mangling_func, 0, "NameMangleExtFuncs", {}); + return Sequential( + {flatten_tuples_pass, remove_default_pass, partition_pass, name_mangling_pass, InferType()}); } -TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph); +TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed([](String mod_name) { + return transform::PartitionGraph(mod_name); +}); } // namespace transform diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 495b3f22e6adb..e3ec155dc2911 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -41,6 +41,14 @@ namespace tvm { namespace runtime { +inline String get_name_mangled(const String& module_name, const String& name) { + std::stringstream ss; + ICHECK(module_name.defined()); + ICHECK(name.defined()); + ss << module_name << "_" << name; + return ss.str(); +} + /*! * \brief Structure that can be optionally used by the executor codegen */ @@ -53,6 +61,8 @@ class MetadataNode : public Object { /*! \brief the executor to be used to run the model */ String executor = kTvmExecutorGraph; + String mod_name = ""; + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "MetadataObj"; TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, Object); @@ -63,11 +73,12 @@ class MetadataNode : public Object { */ class Metadata : public ObjectRef { public: - TVM_DLL Metadata(int num_inputs, int num_outputs, String executor) { + TVM_DLL Metadata(int num_inputs, int num_outputs, String executor, String mod_name) { auto n = make_object(); n->num_inputs = num_inputs; n->num_outputs = num_outputs; n->executor = executor; + n->mod_name = mod_name; data_ = std::move(n); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 03fef4709b5e7..0d96b05f6e488 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -48,10 +48,11 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \n"; - decl_stream << "void* " << module_name_ << " = NULL;\n"; CodeGenC::Init(output_ssa); } +void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } + void CodeGenCHost::AddFunction(const PrimFunc& f) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) @@ -390,8 +391,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { // Make sure that the executor function is the last one to be code generated so that all the // symbols are available to tvm_run_func auto fun_name = std::string(kv.first->name_hint); - const bool is_aot_executor_fn = - (fun_name.rfind(::tvm::runtime::symbol::tvm_run_func_prefix, 0) == 0); + bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); if (is_aot_executor_fn) { aot_executor_fn = Downcast(kv.second); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index e54d78030ed9a..10a437a547c1f 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -42,6 +42,8 @@ class CodeGenCHost final : public CodeGenC { void AddFunction(const PrimFunc& f); + void DefineModuleName(); + /*! \brief Add linked parameters, if they are present. */ void DeclareParameters(Map params); void LinkParameters(Map params); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 661df9305036e..ecaf3b16084f6 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -193,16 +193,19 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } void GenerateAOTDescriptor() { + const std::string run_func = ::tvm::runtime::symbol::tvm_run_func_suffix; + const std::string run_func_mangled = runtime::get_name_mangled(metadata_->mod_name, run_func); + const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n"; code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; code_ << "extern \"C\"\n"; code_ << "#endif\n"; - code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix; + code_ << "TVM_DLL int32_t " << run_func_mangled; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle);\n"; - code_ << "const tvm_model_t network = {\n" - << " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n" + code_ << "const tvm_model_t " << network_mangled << " = {\n" + << " .run_func = &" << run_func_mangled << ",\n" << " .num_input_tensors = " << metadata_->num_inputs << ",\n" << " .num_output_tensors = " << metadata_->num_outputs << ", \n" << "};\n"; diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 3141852405635..37e9e6f9c42c8 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -120,7 +120,7 @@ TEST(Relay, BuildModule) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph); + build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index e674c3b741448..b8def1db298d7 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -92,7 +92,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { Target llvm_tgt = Target("llvm"); targets.Set(0, llvm_tgt); - build_f(func, targets, llvm_tgt, runtime::kTvmExecutorGraph); + build_f(func, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); std::string o_fname = std::tmpnam(nullptr); diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py index b10504bbc9611..b81e74b6d8fa5 100644 --- a/tests/python/contrib/test_bnns/test_conv2d_patterns.py +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -57,7 +57,7 @@ def test_pattern_conv2d_with_bias_add(): res = relay.nn.bias_add(res, b, axis=axis) mod = partition(res) - bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") assert bias_is_fused if axis == 1 else not bias_is_fused @@ -73,7 +73,7 @@ def test_pattern_conv2d_with_add(): res = relay.add(res, b) mod = partition(res) - bias_is_fused = is_op_fused(mod["bnns_0"], "add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "add") assert bias_is_fused == should_be_fused @@ -102,6 +102,6 @@ def test_pattern_conv2d_with_non_cons_bias(): res = relay.nn.bias_add(res, b, axis=1) mod = partition(res) - bias_is_fused = is_op_fused(mod["bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") assert not bias_is_fused diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index ce89c90d93796..a872ef51df9ab 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -122,11 +122,11 @@ def test_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"bfb5a50607edb50009c58ae9d4287e4d"} + _compile_hash = {"5d3cee6ecc488c40ecf533c5cbacc534"} if tei.get_ethosn_variant() == 3: _compile_hash = {"896c28b4f06341ea638ead3a593e1aed"} if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"47e216d8ab2bf491708ccf5620bc0d02"} + _compile_hash = {"a48d3bc62852bba469bf4b17e5587929"} if tei.get_ethosn_variant() == 3: _compile_hash = {"2436f523e263f66a063cef902f2f43d7"} if tei.get_ethosn_api_version() == 2011: @@ -152,11 +152,11 @@ def test_inception_v3(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"96116d7e6c7385de0688074a3f889983"} + _compile_hash = {"1bc66e83c3de5a9773a719b179c65b1a"} if tei.get_ethosn_variant() == 3: _compile_hash = {"551cde850c6ef960d19be4f317fb8e68"} if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"8c9d75659cd7bc9ff6dd6d490d28f9b2"} + _compile_hash = {"15be2c639771c18b1cbed40d3156e2cf"} if tei.get_ethosn_variant() == 3: _compile_hash = {"cdd4d7f6453d722ea73224ff9d6a115a"} if tei.get_ethosn_api_version() == 2011: @@ -181,7 +181,7 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"b34aec2a48c591818761ed6b42c133e5"} + _compile_hash = {"578b8ee279911b49912a77a64f5ff620"} if tei.get_ethosn_variant() == 3: _compile_hash = {"30f078bd42757e8686eafa1f28d0d352"} if tei.get_ethosn_api_version() == 2008: @@ -189,7 +189,7 @@ def test_inception_v4(): pytest.skip( "Ethos-N78 20.08 does not support inception_v4 in the default configuration." ) - _compile_hash = {"798292bfa596ca7c32086396b494b46c"} + _compile_hash = {"2f073fb62539469bfdf1bc20737ef748"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} if tei.get_ethosn_variant() == 3: @@ -212,11 +212,11 @@ def test_ssd_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"c312edfc9a946ed4dc7c049d472dae6e", "3183f0fa5eba8f6b9557d14eaf47842d"} + _compile_hash = {"cd335229a2052f30273f127a233bd319", "95dedc29d911cdc6b28207ca08e42470"} if tei.get_ethosn_variant() == 3: _compile_hash = {"deee52e136327436411fc725624ae2ea", "6526509d3cbee014e38c79e22bb29d7f"} if tei.get_ethosn_api_version() == 2008: - _compile_hash = {"5999f26e140dee0d7866491997ef78c5", "24e3a690a7e95780052792d5626c85be"} + _compile_hash = {"6f37a3b862ddf03077f46b770474ba32", "31a766a2ede0fb8a1af7f2bcb7ffabb7"} if tei.get_ethosn_variant() == 3: _compile_hash = {"da871b3f03a93df69d704ed44584d6cd", "9f52411d301f3cba3f6e4c0f1c558e87"} if tei.get_ethosn_api_version() == 2011: diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index f9912c9674e53..55a357b2aa15d 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1183,9 +1183,9 @@ def get_expected(): var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32") kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32") out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4]) - f1 = GlobalVar("tensorrt_0") + f1 = GlobalVar("tvmgen_default_tensorrt_0") func = relay.Function([var1, kernel_trt], out1) - func = set_func_attr(func, "tensorrt", "tensorrt_0") + func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0") mod[f1] = func mod = relay.transform.InferType()(mod) @@ -1402,7 +1402,7 @@ def test_empty_subgraph(): var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32") f1 = GlobalVar("tensorrt_0") func = relay.Function([var1], var1) - func = set_func_attr(func, "tensorrt", "tensorrt_0") + func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0") mod[f1] = func mod = relay.transform.InferType()(mod) diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py index 4d5d5dc92c419..18c57d485d76f 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py @@ -269,7 +269,6 @@ def partition(dpu_target): with tvm.transform.PassContext(opt_level=3): mod = opt_pass(mod) - return mod def expected(): @@ -289,8 +288,8 @@ def expected(): func0 = relay.Function( [data0, weight0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0], bn.astuple() ) - func0 = set_func_attr(func0, "vitis_ai", "vitis_ai_0") - gv0 = relay.GlobalVar("vitis_ai_0") + func0 = set_func_attr(func0, "vitis_ai", "tvmgen_default_vitis_ai_0") + gv0 = relay.GlobalVar("tvmgen_default_vitis_ai_0") mod = tvm.IRModule() mod[gv0] = func0 mod = relay.transform.InferType()(mod) diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/aot_test.mk index 793a8b1ea69a9..2426d9fd2963f 100644 --- a/tests/python/relay/aot/aot_test.mk +++ b/tests/python/relay/aot/aot_test.mk @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# Makefile to build ethosu_test_runner # Setup build environment # AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot @@ -47,7 +46,7 @@ CRT_SRCS = $(shell find $(CRT_ROOT)) aot_test_runner: $(build_dir)/aot_test_runner -source_libs= $(wildcard $(build_dir)/../codegen/host/src/lib*.c) +source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c) lib_objs =$(source_libs:.c=.o) $(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/aot_executor.o $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index c1917674873de..539009d3e1774 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -32,10 +32,16 @@ from tvm.relay import transform from tvm.contrib import utils, graph_executor from tvm.relay.backend import compile_engine +from tvm.relay.backend.utils import mangle_module_name from tvm.contrib import utils from tvm.micro import export_model_library_format +def mangle_name(mod_name, name): + mod_name = mangle_module_name(mod_name) + return mod_name + "_" + name + + def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): """ This method runs a process and logs the output to both a log file and stdout @@ -56,28 +62,16 @@ def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): print(text, end="") -def create_main(test_name, input_list, output_list, output_path, workspace_bytes): - file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() - # create header file - raw_path = file_path.with_suffix(".c").resolve() - with open(raw_path, "w") as main_file: - main_file.write("#include \n") - main_file.write("#include \n") - main_file.write('#include "tvm/runtime/crt/internal/aot_executor/aot_executor.h"\n') - main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') - main_file.write(f"#define WORKSPACE_SIZE ({workspace_bytes})\n") - main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") - - for i in range(0, len(input_list)): - main_file.write('#include "input_data%i.h"\n' % i) - for i in range(0, len(output_list)): - main_file.write('#include "expected_output_data%i.h"\n' % i) - main_file.write('#include "output_data%i.h"\n' % i) - - main_file.write("extern tvm_model_t network;\n") - main_file.write("tvm_workspace_t app_workspace;\n") - main_file.write( - """ +def emit_main_network_definition(main_file, mod_name): + main_file.write(f'extern tvm_model_t {mangle_name(mod_name,"network")};\n') + + +def emit_main_prologue(main_file, workspace_bytes): + main_file.write(f"#define WORKSPACE_SIZE ({workspace_bytes})\n") + main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") + main_file.write("tvm_workspace_t app_workspace;\n") + main_file.write( + """ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) { return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr); } @@ -91,48 +85,102 @@ def create_main(test_name, input_list, output_list, output_path, workspace_bytes void TVMLogf(const char* msg, ...) { } TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} +int main(){\n """ - ) - main_file.write("int main(){\n") - main_file.write("void* inputs[%i] = { " % (len(input_list))) - - for i in range(0, len(input_list)): - main_file.write("input_data%i, " % i) - main_file.write("};\n") - - main_file.write("void* outputs[%i] = { " % (len(output_list))) - for i in range(0, len(output_list)): - main_file.write("output_data%i, " % i) - main_file.write("};\n") - - main_file.write("StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);") - main_file.write("tvm_runtime_run(&network, inputs, outputs);") - - for i in range(0, len(output_list)): - is_float_dtype = output_list[i].dtype == "float32" - main_file.write("for (int i = 0; i 0.001f){printf("ko\\n");return -1;}\n' - % (i, i) - ) - else: - main_file.write( - 'if (output_data%s[i]!=expected_output_data%s[i]){printf("ko\\n");return -1;}\n' - % (i, i) - ) - main_file.write("}\n") - - main_file.write('printf("ok\\n");') - main_file.write("return 0;") + ) + + +def emit_main_data(main_file, input_list, output_list, mod_name): + for i in range(0, len(input_list)): + main_file.write(f'#include "{mangle_name(mod_name,"input_data")}{i}.h"\n') + + for i in range(0, len(output_list)): + main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n') + main_file.write(f'#include "{mangle_name(mod_name,"output_data")}{i}.h"\n') + + +def emit_main_run(main_file, input_list, output_list, mod_name): + num_outputs = len(output_list) + num_inputs = len(input_list) + + main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ') + + for i in range(0, len(input_list)): + main_file.write(f'{mangle_name(mod_name,"input_data")}{i}, ') + main_file.write("};\n") + + main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ') + for i in range(0, len(output_list)): + main_file.write(f'{mangle_name(mod_name,"output_data")}{i}, ') + main_file.write("};\n") + main_file.write( + f'tvm_runtime_run(&{mangle_name(mod_name,"network")}, {mangle_name(mod_name,"inputs")}, {mangle_name(mod_name,"outputs")});' + ) + + +def emit_main_compare(main_file, output_list, mod_name): + for i in range(0, len(output_list)): + is_float_dtype = output_list[i].dtype == "float32" + main_file.write(f'for (int i = 0; i<{mangle_name(mod_name,"output_data")}{i}_len; i++){{\n') + if is_float_dtype: + main_file.write( + f'if (fabs({mangle_name(mod_name,"output_data")}{i}[i]-{mangle_name(mod_name,"expected_output_data")}{i}[i]) > 0.001f){{printf("ko\\n");return -1;}}\n' + ) + else: + main_file.write( + f'if ({mangle_name(mod_name,"output_data")}{i}[i]!={mangle_name(mod_name, "expected_output_data")}{i}[i]){{printf("ko\\n");return -1;}}\n' + ) main_file.write("}\n") +def emit_main_init_memory_manager(main_file): + main_file.write("StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);") + + +def emit_main_epilogue(main_file): + main_file.write('printf("ok\\n");') + main_file.write("return 0;") + main_file.write("}\n") + + +def emit_main_common_includes(main_file): + main_file.write("#include \n") + main_file.write("#include \n") + main_file.write('#include "tvm/runtime/crt/internal/aot_executor/aot_executor.h"\n') + main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') + + +def create_main(test_name, input_list_map, output_list_map, output_path, workspace_bytes): + file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() + # create header file + raw_path = file_path.with_suffix(".c").resolve() + with open(raw_path, "w") as main_file: + emit_main_common_includes(main_file) + + for k in input_list_map: + emit_main_network_definition(main_file, k) + + emit_main_prologue(main_file, workspace_bytes) + + for k in input_list_map: + emit_main_data(main_file, input_list_map[k], output_list_map[k], k) + + emit_main_init_memory_manager(main_file) + + for k in input_list_map: + emit_main_run(main_file, input_list_map[k], output_list_map[k], k) + + for k in input_list_map: + emit_main_compare(main_file, output_list_map[k], k) + + emit_main_epilogue(main_file) + + def create_header_file(tensor_name, npy_data, output_path): """ This method generates a header file containing the data contained in the numpy array provided. - It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone ethosu_test_runner. + It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application. """ file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve() # create header file @@ -165,7 +213,13 @@ def extract_main_workspace_sizebytes(extract_dir): def compile_and_run( - mod, input_list, output_list, use_calculated_workspaces, params=None, workspace_byte_alignment=8 + mod, + input_list, + output_list, + use_calculated_workspaces, + params=None, + workspace_byte_alignment=8, + mod_name=None, ): """ This method verifies the generated source @@ -178,7 +232,7 @@ def compile_and_run( cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lib = tvm.relay.build(mod, target, target_host=target, params=params) + lib = tvm.relay.build(mod, target, target_host=target, params=params, mod_name=mod_name) tmp_path = utils.tempdir() tmp_dir = tmp_path.temp_dir @@ -197,17 +251,21 @@ def compile_and_run( workspace_bytes = 16384 * 1024 for i in range(len(input_list)): - create_header_file((f"input_data{i}"), input_list[i], build_path) + create_header_file((f'{mangle_name(mod_name, "input_data")}{i}'), input_list[i], build_path) for i in range(len(output_list)): create_header_file( - (f"output_data{i}"), + (f'{mangle_name(mod_name,"output_data")}{i}'), np.zeros(output_list[i].shape, output_list[i].dtype), build_path, ) - create_header_file((f"expected_output_data{i}"), output_list[i], build_path) + create_header_file( + (f'{mangle_name(mod_name, "expected_output_data")}{i}'), output_list[i], build_path + ) - create_main("test.c", input_list, output_list, build_path, workspace_bytes) + create_main( + "test.c", {mod_name: input_list}, {mod_name: output_list}, build_path, workspace_bytes + ) # Verify that compiles fine file_dir = os.path.dirname(os.path.abspath(__file__)) @@ -228,6 +286,64 @@ def compile_and_run( assert ret == 0 +def compile_and_run_multiple_models(mod_map, input_list_map, output_list_map, param_map): + """ + This method verifies the generated source + """ + target = "c -runtime=c --link-params --executor=aot" + tmp_path = utils.tempdir() + tmp_dir = tmp_path.temp_dir + + base_path = os.path.join(tmp_dir, "test") + build_path = os.path.join(base_path, "build") + os.makedirs(build_path, exist_ok=True) + for mod_name, mod in mod_map.items(): + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + lib = tvm.relay.build( + mod, target, target_host=target, params=param_map[mod_name], mod_name=mod_name + ) + + tar_file = os.path.join(base_path, "test.tar") + export_model_library_format(lib, tar_file) + t = tarfile.open(tar_file) + t.extractall(base_path) + + input_list = input_list_map[mod_name] + output_list = output_list_map[mod_name] + + for i in range(len(input_list_map[mod_name])): + create_header_file( + (f'{mangle_name(mod_name,"input_data")}{i}'), input_list[i], build_path + ) + + for i in range(len(output_list_map[mod_name])): + create_header_file( + (f'{mangle_name(mod_name,"output_data")}{i}'), + np.zeros(output_list[i].shape, output_list[i].dtype), + build_path, + ) + create_header_file( + (f'{mangle_name(mod_name,"expected_output_data")}{i}'), output_list[i], build_path + ) + + create_main("test.c", input_list_map, output_list_map, build_path, workspace_bytes=16384 * 1024) + + # Verify that compiles fine + file_dir = os.path.dirname(os.path.abspath(__file__)) + makefile = os.path.join(file_dir, "aot_test.mk") + make_cmd = f"make -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.." + + compile_log_path = os.path.join(build_path, "test_compile.log") + ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) + assert ret == 0 + + # Verify that runs fine + run_log_path = os.path.join(build_path, "test_run.log") + ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) + assert ret == 0 + + def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" compile_engine.get().clear() diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 02b4de3a64f34..762eb95019b58 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -348,7 +348,8 @@ def test_byoc_utvm(use_calculated_workspaces): mod = tvm.IRModule() ann = CcompilerAnnotator() mod["main"] = ann.visit(f) - mod = tvm.relay.transform.PartitionGraph()(mod) + + mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) mod = tvm.relay.transform.InferType()(mod) x_data = np.random.rand(10, 10).astype("float32") @@ -361,7 +362,78 @@ def test_byoc_utvm(use_calculated_workspaces): output_list = generate_ref_data(mod, map_inputs) input_list = [map_inputs["x"]] input_list.extend([map_inputs["w{}".format(i)] for i in range(8)]) - compile_and_run(mod, input_list, output_list, use_calculated_workspaces) + compile_and_run(mod, input_list, output_list, use_calculated_workspaces, mod_name="my_mod") + + +def test_add_name_mangling_with_params(): + x = relay.var("x", shape=(1, 10)) + y = relay.var("y", shape=(1, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + x_in = np.ones((1, 10)).astype("float32") + y_in = np.random.uniform(size=(1, 10)).astype("float32") + + params = {"x": x_in} + inputs = {"y": y_in} + output_list = generate_ref_data(func, inputs, params) + + input_list = [y_in] + compile_and_run( + func, + input_list, + output_list, + use_calculated_workspaces=False, + params=params, + mod_name="my_mod", + ) + + +def test_multiple_models(): + # Identity model without params + x = relay.var("x", "float32") + mod1 = relay.Function([x], x) + one = np.array(1.0, "float32") + inputs1 = {"x": one} + output_list1 = generate_ref_data(mod1, inputs1) + input_list1 = [one] + params1 = None + + # Convolution model + RELAY_MODEL = """ +#[version = "0.0.5"] +def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { + %1 = nn.conv2d( + %data, + %weight, + padding=[2, 2], + channels=8, + kernel_size=[5, 5], + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32"); + %1 +} +""" + mod2 = tvm.parser.fromtext(RELAY_MODEL) + main_func = mod2["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + + params2 = {"weight": weight_data} + inputs2 = {"data": input_data} + output_list2 = generate_ref_data(mod2, inputs2, params2) + input_list2 = [input_data] + + input_list_map = {"mod1": input_list1, "mod2": input_list2} + output_list_map = {"mod1": output_list1, "mod2": output_list2} + mod_map = {"mod1": mod1, "mod2": mod2} + param_map = {"mod1": params1, "mod2": params2} + + compile_and_run_multiple_models(mod_map, input_list_map, output_list_map, param_map) if __name__ == "__main__": diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index 52e082e27b747..f5674dbf5fb39 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -104,8 +104,8 @@ def conv2d_direct(): out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1)) func = relay.Function([data0, weight0], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -139,8 +139,8 @@ def group_conv2d(): out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=32) func = relay.Function([data0, weight0], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -183,8 +183,8 @@ def gen_add(): out = relay.add(data0, data1) func = relay.Function([data0, data1], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -226,8 +226,8 @@ def gen_relu(): out = relay.nn.relu(data0) func = relay.Function([data0], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -276,8 +276,8 @@ def gen_dense(): out = relay.nn.dense(a, b) func = relay.Function([a, b], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -325,8 +325,8 @@ def gen_bn(): out = bn[0] func = relay.Function([data, gamma, beta, moving_mean, moving_var], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -471,8 +471,8 @@ def conv2d_relu(): arg_2 = relay.var("arg_2", shape=w1shape, dtype=dtype) call = relay.Call(func, [arg_1, arg_2]) p_func = relay.Function([arg_1, arg_2], call) - p_func = set_func_attr(p_func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + p_func = set_func_attr(p_func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = p_func mod = transform.InferType()(mod) @@ -521,8 +521,8 @@ def conv2d_bias_relu(): arg_3 = relay.var("arg_3", shape=bshape, dtype=dtype) call = relay.Call(func, [arg_1, arg_2, arg_3]) p_func = relay.Function([arg_1, arg_2, arg_3], call) - p_func = set_func_attr(p_func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + p_func = set_func_attr(p_func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = p_func mod = transform.InferType()(mod) diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py index f968dbedddfe2..20ccefed8513b 100644 --- a/tests/python/relay/test_op_fast_math.py +++ b/tests/python/relay/test_op_fast_math.py @@ -41,7 +41,7 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"): graph, lib, params = relay.build(mod, target=target, params=None) # Check that the op related to fast math have been convered to function in lib - func_name = "fused_" + name + func_name = "tvmgen_default_fused_" + name # When there're multiple targets in tvm.testing.parametrize_targets, the function # built will have a "_1" in function name assert func_name in graph diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 4db8bd5e7b5bc..98d7161ae36c3 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -339,8 +339,8 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = set_func_attr(func, "ccompiler", "ccompiler_0") - glb_0 = relay.GlobalVar("ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) # Function that uses default compiler. Ops are fused in this function. @@ -416,8 +416,8 @@ def expected(): out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) func = relay.Function([data0, input0], out) - func = set_func_attr(func, "dnnl", "dnnl_0") - glb_var = relay.GlobalVar("dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -532,8 +532,8 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "test_compiler_2") - gv0 = relay.GlobalVar("test_compiler_2") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_2") + gv0 = relay.GlobalVar("tvmgen_default_test_compiler_2") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -544,8 +544,8 @@ def expected(): data=data1, weight=weight1, kernel_size=(3, 3), channels=16, padding=(1, 1) ) func1 = relay.Function([data1, weight1], conv) - func1 = set_func_attr(func1, "test_compiler", "test_compiler_0") - gv1 = relay.GlobalVar("test_compiler_0") + func1 = set_func_attr(func1, "test_compiler", "tvmgen_default_test_compiler_0") + gv1 = relay.GlobalVar("tvmgen_default_test_compiler_0") mod[gv1] = func1 mod = transform.InferType()(mod) @@ -613,7 +613,7 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -643,8 +643,8 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([y0], add) - func = set_func_attr(func, "ccompiler", "ccompiler_0") - glb_0 = relay.GlobalVar("ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") mod[glb_0] = func mod = relay.transform.InferType()(mod) add_call = relay.Call(glb_0, [y]) @@ -733,8 +733,8 @@ def expected(): tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = set_func_attr(func0, "test_target", "test_target_0") - gv0 = relay.GlobalVar("test_target_0") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_0") + gv0 = relay.GlobalVar("tvmgen_default_test_target_0") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -796,8 +796,8 @@ def expected(): f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - func1 = set_func_attr(func1, "test_target", "test_target_0") - gv1 = relay.GlobalVar("test_target_0") + func1 = set_func_attr(func1, "test_target", "tvmgen_default_test_target_0") + gv1 = relay.GlobalVar("tvmgen_default_test_target_0") mod[gv1] = func1 mod = relay.transform.InferType()(mod) @@ -806,8 +806,8 @@ def expected(): f2_cb4 = relay.var("test_target_1_i1", shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - func0 = set_func_attr(func0, "test_target", "test_target_1") - gv0 = relay.GlobalVar("test_target_1") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_1") + gv0 = relay.GlobalVar("tvmgen_default_test_target_1") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -955,8 +955,8 @@ def expected_same_output_region(): mul = log * sub # The partitioned graph contains log, subtract, and multiply func = relay.Function([x0, y0], mul) - func = set_func_attr(func, "ccompiler", "ccompiler_0") - glb_0 = relay.GlobalVar("ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -977,8 +977,8 @@ def expected_different_output_region(): i0 = relay.var("i0", shape=(8, 8)) log = relay.log(i0) func = relay.Function([i0], log) - func = set_func_attr(func, "ccompiler", "ccompiler_0") - glb_0 = relay.GlobalVar("ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -987,8 +987,8 @@ def expected_different_output_region(): y0 = relay.var("y0", shape=(8, 8)) sub = x0 - y0 func = relay.Function([x0, y0], sub) - func = set_func_attr(func, "ccompiler", "ccompiler_1") - glb_1 = relay.GlobalVar("ccompiler_1") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_1") + glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_1") mod[glb_1] = func mod = transform.InferType()(mod) @@ -1063,8 +1063,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", target + "_0") - gv0 = relay.GlobalVar(target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1140,8 +1140,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", target + "_0") - gv0 = relay.GlobalVar(target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1216,7 +1216,7 @@ def create_graph(): partitioned = seq(create_graph()) - concat = partitioned["const_tuples_0"].body + concat = partitioned["tvmgen_default_const_tuples_0"].body assert type(concat.args[1]) == relay.Tuple assert type(concat.args[2]) == relay.Tuple assert type(concat.args[3]) == relay.Constant @@ -1266,8 +1266,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", target + "_0") - gv0 = relay.GlobalVar(target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1349,7 +1349,7 @@ def Optimize(mod): mod = transform.PartitionGraph()(mod) try: - t0 = mod["test_target_0"] + t0 = mod["tvmgen_default_test_target_0"] except: raise KeyError("test_target_0 not found") diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index d2c519da22b5a..2922a3adf48b2 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -85,7 +85,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 2 + assert metadata["version"] == 3 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -115,8 +115,8 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ in metadata["memory"]["functions"]["operator_functions"][0]["function_name"] ) - assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) - assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib1.c")) if executor == "graph": validate_graph_json(extract_dir, factory) @@ -165,7 +165,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 2 + assert metadata["version"] == 3 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -194,7 +194,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ in metadata["memory"]["functions"]["operator_functions"][0]["function_name"] ) - assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "lib0.o")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "lib", "add_lib0.o")) validate_graph_json(extract_dir, factory) @@ -244,7 +244,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 2 + assert metadata["version"] == 3 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ"