diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index baa2397b2303..9274c47b07c9 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -528,8 +528,12 @@ def compile_and_run( workspace_bytes += extract_main_workspace_size_bytes(base_path) + sanitized_names = [] for key in model.inputs: sanitized_tensor_name = re.sub(r"\W", "_", key) + if sanitized_tensor_name in sanitized_names: + raise ValueError(f"Sanitized input tensor name clash: {sanitized_tensor_name}") + sanitized_names.append(sanitized_tensor_name) create_header_file( f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}', model.inputs[key], diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 225d2b1b1bfa..86c6a958cb27 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -503,31 +503,28 @@ def test_transpose(interface_api, use_unpacked_api, test_runner): ) -def test_name_sanitiser(): +@parametrize_aot_options +def test_name_sanitiser(interface_api, use_unpacked_api, test_runner): """Test that input tensors with special characters in the name don't break compilation""" - use_calculated_workspaces = True - use_unpacked_api = True - interface_api = "c" - x = relay.var("input-x::2", "float32") - ident = relay.Function([x], x) + func = relay.var("input-x::2", "float32") + ident = relay.Function([func], func) one = np.array(1.0, "float32") inputs = {"input-x::2": one} output_list = generate_ref_data(ident, inputs) compile_and_run( - AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list), + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, interface_api, use_unpacked_api, - use_calculated_workspaces, + enable_op_fusion=False, ) -def test_name_sanitiser_name_clash(): +@parametrize_aot_options +def test_name_sanitiser_name_clash(interface_api, use_unpacked_api, test_runner): """Test that 2 input tensors with names that clash once sanitized, generates an error""" - use_calculated_workspaces = True - use_unpacked_api = True - interface_api = "c" dtype = "float32" x = relay.var("input::-1", shape=(10, 5), dtype=dtype) @@ -549,9 +546,9 @@ def test_name_sanitiser_name_clash(): with pytest.raises(ValueError, match="Sanitized input tensor name clash"): compile_and_run( AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, interface_api, use_unpacked_api, - use_calculated_workspaces, enable_op_fusion=False, )