Skip to content

Commit

Permalink
Sanitize names of input tensors in interface header
Browse files Browse the repository at this point in the history
Change-Id: I7f02a993887bf84316262cd2586a734a9079c338
  • Loading branch information
grant-arm committed Aug 15, 2021
1 parent e12ddca commit ff94e42
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/tvm/micro/interface_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Defines functions for generating a C interface header"""

import os
import re

from tvm.relay.backend.utils import mangle_module_name

Expand Down Expand Up @@ -59,7 +60,8 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path):
_emit_brief(header_file, module_name, "Input tensor pointers")
header_file.write(f"struct {mangled_name}_inputs {{\n")
for input_name in inputs:
header_file.write(f" void* {input_name};\n")
sanitized_input_name = re.sub(r"\W+", "_", input_name)
header_file.write(f" void* {sanitized_input_name};\n")
header_file.write("};\n\n")

_emit_brief(header_file, module_name, "Output tensor pointers")
Expand Down
4 changes: 3 additions & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
<< ") {";
code_ << "return " << run_func << "(";
for (const auto& input : metadata_->inputs) {
code_ << "inputs->" << input << ",";
std::string sanitised_input = input;
std::replace_if(sanitised_input.begin(), sanitised_input.end(), ::ispunct, '_');
code_ << "inputs->" << sanitised_input << ",";
}
if (metadata_->num_outputs == 1) {
code_ << "outputs->output";
Expand Down
17 changes: 13 additions & 4 deletions tests/python/relay/aot/aot_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import pathlib
import re
import shutil
import subprocess
import tarfile
Expand Down Expand Up @@ -189,7 +190,10 @@ def emit_main_prologue(main_file, workspace_bytes):

def emit_main_data(main_file, input_map, output_list, mod_name):
for key in input_map:
main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n')
sanitized_tensor_name = re.sub(r"\W+", "_", key)
main_file.write(
f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n'
)

for i in range(0, len(output_list)):
main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n')
Expand All @@ -201,7 +205,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name):
f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{"
)
for key in input_map:
main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n")
sanitized_tensor_name = re.sub(r"\W+", "_", key)
main_file.write(
f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n"
)
main_file.write("};\n")

main_file.write(
Expand All @@ -222,7 +229,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name):

main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ')
for key in input_map:
main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ')
sanitized_tensor_name = re.sub(r"\W+", "_", key)
main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ')
main_file.write("};\n")

main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ')
Expand Down Expand Up @@ -458,8 +466,9 @@ def compile_and_run(
workspace_bytes += 16384 * 1024

for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W+", "_", key)
create_header_file(
f'{mangle_name(model.name, "input_data")}_{key}',
f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
model.inputs[key],
include_path,
)
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,25 @@ def test_transpose(interface_api, use_unpacked_api, use_calculated_workspaces):
)


def test_name_sanitiser():
"""Test that input tensors with special characters in the name don't break compilation"""
use_calculated_workspaces = True
use_unpacked_api = True
interface_api = "c"

x = relay.var("input-x:2", "float32")
ident = relay.Function([x], x)
one = np.array(1.0, "float32")
inputs = {"input-x:2": one}
output_list = generate_ref_data(ident, inputs)

compile_and_run(
AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list),
interface_api,
use_unpacked_api,
use_calculated_workspaces,
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit ff94e42

Please sign in to comment.