Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize names of input tensors in interface header #8720

Merged
merged 3 commits into from
Sep 3, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Only test unpacked, C interface API, AOT case
Change-Id: I9082ae32079a1a3924c06c7f26c757aafa46dec2
grant-arm committed Sep 1, 2021

Verified

This commit was signed with the committer’s verified signature.
not-an-aardvark Teddy Katz
commit 1e07a0eef5d54b7e12379bb21f4ed0f3f75d3754
4 changes: 0 additions & 4 deletions tests/python/relay/aot/aot_test_utils.py
Original file line number Diff line number Diff line change
@@ -528,12 +528,8 @@ def compile_and_run(

workspace_bytes += extract_main_workspace_size_bytes(base_path)

sanitized_names = []
for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W", "_", key)
if sanitized_tensor_name in sanitized_names:
raise ValueError(f"Sanitized input tensor name clash: {sanitized_tensor_name}")
sanitized_names.append(sanitized_tensor_name)
create_header_file(
f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
model.inputs[key],
14 changes: 10 additions & 4 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
@@ -503,10 +503,13 @@ def test_transpose(interface_api, use_unpacked_api, test_runner):
)


@parametrize_aot_options
def test_name_sanitiser(interface_api, use_unpacked_api, test_runner):
def test_name_sanitiser():
"""Test that input tensors with special characters in the name don't break compilation"""

interface_api = "c"
use_unpacked_api = True
test_runner = AOT_DEFAULT_RUNNER

func = relay.var("input-x::2", "float32")
ident = relay.Function([func], func)
one = np.array(1.0, "float32")
@@ -522,10 +525,13 @@ def test_name_sanitiser(interface_api, use_unpacked_api, test_runner):
)


@parametrize_aot_options
def test_name_sanitiser_name_clash(interface_api, use_unpacked_api, test_runner):
def test_name_sanitiser_name_clash():
"""Test that 2 input tensors with names that clash once sanitized, generates an error"""

interface_api = "c"
use_unpacked_api = True
test_runner = AOT_DEFAULT_RUNNER

dtype = "float32"
x = relay.var("input::-1", shape=(10, 5), dtype=dtype)
# Next 2 input tensor names will clash once sanitized.