Skip to content

Commit

Permalink
Sanitize names of input tensors in interface header
Browse files Browse the repository at this point in the history
Change-Id: I7f02a993887bf84316262cd2586a734a9079c338
  • Loading branch information
grant-arm committed Aug 20, 2021
1 parent e12ddca commit 3b0e7bf
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 6 deletions.
13 changes: 12 additions & 1 deletion python/tvm/micro/interface_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

"""Defines functions for generating a C interface header"""

# TODO: Currently the Interface API header is generated in Python but the source it references
# is generated in C++. These should be consolidated to generate both header and source in C++
# and avoid re-implementing logic, such as name sanitising, in the two different languages.
# See https://github.com/apache/tvm/issues/8792 .

import os
import re

from tvm.relay.backend.utils import mangle_module_name

Expand Down Expand Up @@ -58,8 +64,13 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path):

_emit_brief(header_file, module_name, "Input tensor pointers")
header_file.write(f"struct {mangled_name}_inputs {{\n")
sanitized_names = []
for input_name in inputs:
header_file.write(f" void* {input_name};\n")
sanitized_input_name = re.sub(r"\W", "_", input_name)
if sanitized_input_name in sanitized_names:
raise ValueError(f"Sanitized input tensor name clash: {sanitized_input_name}")
sanitized_names.append(sanitized_input_name)
header_file.write(f" void* {sanitized_input_name};\n")
header_file.write("};\n\n")

_emit_brief(header_file, module_name, "Output tensor pointers")
Expand Down
6 changes: 5 additions & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
code_ << "}\n";
}

static int isNotAlnum(char c) { return !std::isalnum(c); }

void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
const std::string& mod_name) {
code_ << "#include <" << mod_name << ".h>\n";
Expand All @@ -252,7 +254,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
<< ") {";
code_ << "return " << run_func << "(";
for (const auto& input : metadata_->inputs) {
code_ << "inputs->" << input << ",";
std::string sanitised_input = input;
std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_');
code_ << "inputs->" << sanitised_input << ",";
}
if (metadata_->num_outputs == 1) {
code_ << "outputs->output";
Expand Down
17 changes: 13 additions & 4 deletions tests/python/relay/aot/aot_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import pathlib
import re
import shutil
import subprocess
import tarfile
Expand Down Expand Up @@ -189,7 +190,10 @@ def emit_main_prologue(main_file, workspace_bytes):

def emit_main_data(main_file, input_map, output_list, mod_name):
for key in input_map:
main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n')
sanitized_tensor_name = re.sub(r"\W", "_", key)
main_file.write(
f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n'
)

for i in range(0, len(output_list)):
main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n')
Expand All @@ -201,7 +205,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name):
f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{"
)
for key in input_map:
main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n")
sanitized_tensor_name = re.sub(r"\W", "_", key)
main_file.write(
f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n"
)
main_file.write("};\n")

main_file.write(
Expand All @@ -222,7 +229,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name):

main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ')
for key in input_map:
main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ')
sanitized_tensor_name = re.sub(r"\W", "_", key)
main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ')
main_file.write("};\n")

main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ')
Expand Down Expand Up @@ -458,8 +466,9 @@ def compile_and_run(
workspace_bytes += 16384 * 1024

for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W", "_", key)
create_header_file(
f'{mangle_name(model.name, "input_data")}_{key}',
f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
model.inputs[key],
include_path,
)
Expand Down
53 changes: 53 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,58 @@ def test_transpose(interface_api, use_unpacked_api, use_calculated_workspaces):
)


def test_name_sanitiser():
"""Test that input tensors with special characters in the name don't break compilation"""
use_calculated_workspaces = True
use_unpacked_api = True
interface_api = "c"

x = relay.var("input-x::2", "float32")
ident = relay.Function([x], x)
one = np.array(1.0, "float32")
inputs = {"input-x::2": one}
output_list = generate_ref_data(ident, inputs)

compile_and_run(
AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list),
interface_api,
use_unpacked_api,
use_calculated_workspaces,
)


def test_name_sanitiser_name_clash():
"""Test that 2 input tensors with names that clash once sanitized, generates an error"""
use_calculated_workspaces = True
use_unpacked_api = True
interface_api = "c"

dtype = "float32"
x = relay.var("input::-1", shape=(10, 5), dtype=dtype)
# Next 2 input tensor names will clash once sanitized.
y = relay.var("input::-2", shape=(10, 5), dtype=dtype)
t = relay.var("input:--2", shape=(), dtype=dtype)
a = relay.add(x, y)
b = relay.transpose(a)
z = relay.add(b, t)
# Check result.
func = relay.Function([x, y, t], z)
x_data = np.random.rand(10, 5).astype(dtype)
y_data = np.random.rand(10, 5).astype(dtype)
t_data = np.random.uniform(size=()).astype(dtype)

inputs = {"input::-1": x_data, "input::-2": y_data, "input:--2": t_data}
output_list = generate_ref_data(func, inputs)

with pytest.raises(ValueError, match="Sanitized input tensor name clash"):
compile_and_run(
AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
interface_api,
use_unpacked_api,
use_calculated_workspaces,
enable_op_fusion=False,
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 3b0e7bf

Please sign in to comment.