diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index c54973d58a767..1db893cc99287 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -131,7 +131,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1, tvm_ops=0, acl_partitions=1, config=None): """Build and run the relay module.""" - if not config: + if config is None: config = {} try: @@ -177,12 +177,12 @@ def verify(answers, atol, rtol, verify_saturation=False, config=None): f"No results to compare: expected at least two, found {len(answers)}") for answer in zip_longest(*answers): for outs in combinations(answer, 2): - if verify_saturation: - assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \ - "Output is saturated: {}".format(outs[0]) - assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \ - "Output is saturated: {}".format(outs[0]) try: + if verify_saturation: + assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \ + "Output is saturated: {}".format(outs[0]) + assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \ + "Output is saturated: {}".format(outs[0]) tvm.testing.assert_allclose( outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol) except AssertionError as e: diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index b24b68ec2f23e..cb5305a52f87f 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -149,7 +149,7 @@ def test_pooling(): pad = [(0, 0), (1, 1), (0, 1)] ceil_mode = [False, True] count_include_pad = [False, True] - input_shapes = [(8, 8, 16), (9, 9, 16)] + input_shapes = [(16, 16, 16), (15, 15, 16)] trials = generate_trials([typef, dtype, size, stride, pad, ceil_mode, count_include_pad, input_shapes], 3) for typef, (dtype, low, high, atol, rtol), size, stride, pad, ceil_mode, count_include_pad, \ @@ -178,12 +178,13 @@ def test_pooling(): "ceil_mode": ceil_mode, "count_include_pad": count_include_pad } + verify_saturation = True if dtype == "uint8" else False for acl in [False, True]: outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0]) - verify(outputs, atol=atol, rtol=rtol, config=config) + verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) def test_global_pooling(): @@ -212,11 +213,13 @@ def test_global_pooling(): "pooling type": typef, "dtype": dtype, } + verify_saturation = True if dtype == "uint8" else False for acl in [False, True]: outputs.append(build_and_run(func, inputs, 1, None, device, enable_acl=acl, config=config)[0]) - verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=True) + + verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation) def test_codegen_pooling():