Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize names of input tensors in interface header #8720

Merged
merged 3 commits into from
Sep 3, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update tensor name sanitizer tests to parameterize them.
Change-Id: I157d8d8d607de2904285e403893f146e97b510d5
grant-arm committed Sep 1, 2021

Verified

This commit was signed with the committer’s verified signature.
not-an-aardvark Teddy Katz
commit 0f3ad935bbadef84a9309d04d0ddbb83cc09f490
4 changes: 4 additions & 0 deletions tests/python/relay/aot/aot_test_utils.py
Original file line number Diff line number Diff line change
@@ -528,8 +528,12 @@ def compile_and_run(

workspace_bytes += extract_main_workspace_size_bytes(base_path)

sanitized_names = []
for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W", "_", key)
if sanitized_tensor_name in sanitized_names:
raise ValueError(f"Sanitized input tensor name clash: {sanitized_tensor_name}")
sanitized_names.append(sanitized_tensor_name)
create_header_file(
f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
model.inputs[key],
23 changes: 10 additions & 13 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
@@ -503,31 +503,28 @@ def test_transpose(interface_api, use_unpacked_api, test_runner):
)


def test_name_sanitiser():
@parametrize_aot_options
def test_name_sanitiser(interface_api, use_unpacked_api, test_runner):
"""Test that input tensors with special characters in the name don't break compilation"""
use_calculated_workspaces = True
use_unpacked_api = True
interface_api = "c"

x = relay.var("input-x::2", "float32")
ident = relay.Function([x], x)
func = relay.var("input-x::2", "float32")
ident = relay.Function([func], func)
one = np.array(1.0, "float32")
inputs = {"input-x::2": one}
output_list = generate_ref_data(ident, inputs)

compile_and_run(
AOTTestModel(module=IRModule.from_expr(ident), inputs=inputs, outputs=output_list),
AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
test_runner,
interface_api,
use_unpacked_api,
use_calculated_workspaces,
enable_op_fusion=False,
)


def test_name_sanitiser_name_clash():
@parametrize_aot_options
grant-arm marked this conversation as resolved.
Show resolved Hide resolved
def test_name_sanitiser_name_clash(interface_api, use_unpacked_api, test_runner):
"""Test that 2 input tensors with names that clash once sanitized, generates an error"""
use_calculated_workspaces = True
use_unpacked_api = True
interface_api = "c"

dtype = "float32"
x = relay.var("input::-1", shape=(10, 5), dtype=dtype)
@@ -549,9 +546,9 @@ def test_name_sanitiser_name_clash():
with pytest.raises(ValueError, match="Sanitized input tensor name clash"):
compile_and_run(
AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
test_runner,
interface_api,
use_unpacked_api,
use_calculated_workspaces,
enable_op_fusion=False,
)