Skip to content

Commit

Permalink
[microTVM] Fix host-driven AOT memory workspaces (#13807)
Browse files Browse the repository at this point in the history
When using host-driven AOT with memory pools enabled, the workspace and constant memory were not properly supported. In order for them to work properly, the _run function (typically tvmgen_default_run()) needed to be called instead of tvmgen_default___tvm_main__() in order to properly setup the memory workspace pointers.

fixes #13777
alanmacd authored Jan 26, 2023

Verified

This commit was signed with the committer’s verified signature.
mrgrain Momo Kornher
1 parent f7dfef4 commit 4ddb876
Showing 4 changed files with 19 additions and 29 deletions.
13 changes: 1 addition & 12 deletions src/runtime/crt/aot_executor/aot_executor.c
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
}

int TVMAotExecutor_Run(TVMAotExecutor* executor) {
const char* tvm_main_suffix = "___tvm_main__";
const char* tvm_main_suffix = "_run";
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];

{
@@ -203,17 +203,6 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle,
TVMNDArray_IncrementReference(array);
}

for (i = 0; i < md->num_workspace_pools; ++i) {
LOG_DEBUG("pools allocate[%d]: %s\n", i, md->workspace_pools[i].name);

status = TVMNDArray_Empty(md->workspace_pools[i].num_shape, md->workspace_pools[i].shape,
md->workspace_pools[i].dtype, executor->device,
&executor->args[arg_idx++]);
if (status != 0) {
return status;
}
}
CHECK_EQ(0, md->num_constant_pools, "Constant pools not supported");
return status;
}

16 changes: 13 additions & 3 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
@@ -929,11 +929,21 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& mod
relay::backend::ExecutorCodegenMetadata metadata,
runtime::metadata::Metadata aot_metadata) {
Array<runtime::Module> final_modules(modules);
if (aot_metadata.defined()) {
final_modules.push_back(CreateAotMetadataModule(aot_metadata, true));
Array<String> func_names;

if (metadata.defined()) {
if (metadata->executor == "aot") {
if (aot_metadata.defined()) {
final_modules.push_back(CreateAotMetadataModule(aot_metadata, true));
}

// add the run function (typically "tvmgen_default_run") to function registry
// when using AOT executor
std::string run_func = runtime::get_name_mangled(metadata->mod_name, "run");
func_names.push_back(run_func);
}
}

Array<String> func_names;
for (runtime::Module mod : final_modules) {
auto pf_funcs = mod.GetFunction("get_func_names");
if (pf_funcs != nullptr) {
18 changes: 5 additions & 13 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
@@ -229,15 +229,10 @@ def do_test():
do_test()


enable_usmp, expect_exception = tvm.testing.parameters((True, True), (False, False))


@tvm.testing.requires_micro
def test_aot_executor_usmp_const_pool(enable_usmp, expect_exception):
"""Test the AOT executor with microTVM using usmp.
Test should fail if const pool is supplied to executor
as these are currently not supported
"""
def test_aot_executor_usmp_const_pool():
"""Test the AOT executor with microTVM using USMP to generate a constant data pool."""

ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace-usmp")
if ws_root.exists():
shutil.rmtree(ws_root)
@@ -260,7 +255,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1
C_np = np.array([[8, 9]], dtype="uint8").astype(type_dict["c"])
params = {"c": C_np}
with tvm.transform.PassContext(
opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp}
opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": True}
):
factory = tvm.relay.build(
relay_mod,
@@ -278,10 +273,7 @@ def do_test():
)
)
except tvm._ffi.base.TVMError as e:
if expect_exception:
return
else:
raise e
raise e

assert aot_executor.get_input_index("a") == 0
assert aot_executor.get_input_index("b") == 1
1 change: 0 additions & 1 deletion tests/python/unittest/test_micro_model_library_format.py
Original file line number Diff line number Diff line change
@@ -618,7 +618,6 @@ def test_multiple_relay_modules_aot_graph():

assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib0.c"))
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib1.c"))
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib2.c"))
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib0.c"))
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib1.c"))

0 comments on commit 4ddb876

Please sign in to comment.