From 0c187584444b7a9d6bc0e095be89669eb2f108bd Mon Sep 17 00:00:00 2001 From: Grant Watson Date: Tue, 31 Aug 2021 15:56:13 +0100 Subject: [PATCH] Only test unpacked, C interface API, AOT case Change-Id: I9082ae32079a1a3924c06c7f26c757aafa46dec2 --- tests/python/relay/aot/aot_test_utils.py | 4 ---- tests/python/relay/aot/test_crt_aot.py | 14 ++++++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 9274c47b07c9f..baa2397b23037 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -528,12 +528,8 @@ def compile_and_run( workspace_bytes += extract_main_workspace_size_bytes(base_path) - sanitized_names = [] for key in model.inputs: sanitized_tensor_name = re.sub(r"\W", "_", key) - if sanitized_tensor_name in sanitized_names: - raise ValueError(f"Sanitized input tensor name clash: {sanitized_tensor_name}") - sanitized_names.append(sanitized_tensor_name) create_header_file( f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}', model.inputs[key], diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 86c6a958cb279..64000a9d56b37 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -503,10 +503,13 @@ def test_transpose(interface_api, use_unpacked_api, test_runner): ) -@parametrize_aot_options -def test_name_sanitiser(interface_api, use_unpacked_api, test_runner): +def test_name_sanitiser(): """Test that input tensors with special characters in the name don't break compilation""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_DEFAULT_RUNNER + func = relay.var("input-x::2", "float32") ident = relay.Function([func], func) one = np.array(1.0, "float32") @@ -522,10 +525,13 @@ def test_name_sanitiser(interface_api, use_unpacked_api, test_runner): ) -@parametrize_aot_options -def test_name_sanitiser_name_clash(interface_api, use_unpacked_api, test_runner): +def test_name_sanitiser_name_clash(): """Test that 2 input tensors with names that clash once sanitized, generates an error""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_DEFAULT_RUNNER + dtype = "float32" x = relay.var("input::-1", shape=(10, 5), dtype=dtype) # Next 2 input tensor names will clash once sanitized.