From ff94e42971decd395199c3fccbc6f4b478299e3d Mon Sep 17 00:00:00 2001 From: Grant Watson Date: Tue, 10 Aug 2021 15:46:59 +0100 Subject: [PATCH] Sanitize names of input tensors in interface header Change-Id: I7f02a993887bf84316262cd2586a734a9079c338 --- python/tvm/micro/interface_api.py | 4 +++- src/target/source/source_module.cc | 4 +++- tests/python/relay/aot/aot_test_utils.py | 17 +++++++++++++---- tests/python/relay/aot/test_crt_aot.py | 20 ++++++++++++++++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py index 8086b1ed65545..13213efdbea72 100644 --- a/python/tvm/micro/interface_api.py +++ b/python/tvm/micro/interface_api.py @@ -18,6 +18,7 @@ """Defines functions for generating a C interface header""" import os +import re from tvm.relay.backend.utils import mangle_module_name @@ -59,7 +60,8 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): _emit_brief(header_file, module_name, "Input tensor pointers") header_file.write(f"struct {mangled_name}_inputs {{\n") for input_name in inputs: - header_file.write(f" void* {input_name};\n") + sanitized_input_name = re.sub(r"\W+", "_", input_name) + header_file.write(f" void* {sanitized_input_name};\n") header_file.write("};\n\n") _emit_brief(header_file, module_name, "Output tensor pointers") diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 7728773b13d72..e451d2ce7ccc1 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -252,7 +252,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << ") {"; code_ << "return " << run_func << "("; for (const auto& input : metadata_->inputs) { - code_ << "inputs->" << input << ","; + std::string sanitised_input = input; + std::replace_if(sanitised_input.begin(), sanitised_input.end(), ::ispunct, '_'); + code_ << "inputs->" << sanitised_input << ","; } if (metadata_->num_outputs == 1) { code_ << "outputs->output"; diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 36c415ec8c839..a44cfa01e937c 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -21,6 +21,7 @@ import logging import os import pathlib +import re import shutil import subprocess import tarfile @@ -189,7 +190,10 @@ def emit_main_prologue(main_file, workspace_bytes): def emit_main_data(main_file, input_map, output_list, mod_name): for key in input_map: - main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n') + sanitized_tensor_name = re.sub(r"\W+", "_", key) + main_file.write( + f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n' + ) for i in range(0, len(output_list)): main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n') @@ -201,7 +205,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name): f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{" ) for key in input_map: - main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n") + sanitized_tensor_name = re.sub(r"\W+", "_", key) + main_file.write( + f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n" + ) main_file.write("};\n") main_file.write( @@ -222,7 +229,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name): main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ') for key in input_map: - main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ') + sanitized_tensor_name = re.sub(r"\W+", "_", key) + main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ') main_file.write("};\n") main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ') @@ -458,8 +466,9 @@ def compile_and_run( workspace_bytes += 16384 * 1024 for key in model.inputs: + sanitized_tensor_name = re.sub(r"\W+", "_", key) create_header_file( - f'{mangle_name(model.name, "input_data")}_{key}', + f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}', model.inputs[key], include_path, ) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index abbf350bff775..caecff903fb96 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -500,5 +500,25 @@ def test_transpose(interface_api, use_unpacked_api, use_calculated_workspaces): ) +def test_name_sanitiser(): + """Test that input tensors with special characters in the name don't break compilation""" + use_calculated_workspaces = True + use_unpacked_api = True + interface_api = "c" + + x = relay.var("input-x:2", "float32") + ident = relay.Function([x], x) + one = np.array(1.0, "float32") + inputs = {"input-x:2": one} + output_list = generate_ref_data(ident, inputs) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list), + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))