Skip to content

Commit

Permalink
Fix not checking output saturation
Browse files Browse the repository at this point in the history
Change-Id: Ia6f3d9db31cfb8c417d8556d29961210fea418b2
  • Loading branch information
lhutton1 committed Aug 14, 2020
1 parent 03f1b3c commit ec9befa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
12 changes: 6 additions & 6 deletions tests/python/contrib/test_arm_compute_lib/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti
def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, no_runs=1,
tvm_ops=0, acl_partitions=1, config=None):
"""Build and run the relay module."""
if not config:
if config is None:
config = {}

try:
Expand Down Expand Up @@ -177,12 +177,12 @@ def verify(answers, atol, rtol, verify_saturation=False, config=None):
f"No results to compare: expected at least two, found {len(answers)}")
for answer in zip_longest(*answers):
for outs in combinations(answer, 2):
if verify_saturation:
assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
"Output is saturated: {}".format(outs[0])
assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
"Output is saturated: {}".format(outs[0])
try:
if verify_saturation:
assert np.count_nonzero(outs[0].asnumpy() == 255) < 0.25 * outs[0].asnumpy().size, \
"Output is saturated: {}".format(outs[0])
assert np.count_nonzero(outs[0].asnumpy() == 0) < 0.25 * outs[0].asnumpy().size, \
"Output is saturated: {}".format(outs[0])
tvm.testing.assert_allclose(
outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
except AssertionError as e:
Expand Down
9 changes: 6 additions & 3 deletions tests/python/contrib/test_arm_compute_lib/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_pooling():
pad = [(0, 0), (1, 1), (0, 1)]
ceil_mode = [False, True]
count_include_pad = [False, True]
input_shapes = [(8, 8, 16), (9, 9, 16)]
input_shapes = [(16, 16, 16), (15, 15, 16)]
trials = generate_trials([typef, dtype, size, stride, pad, ceil_mode, count_include_pad, input_shapes], 3)

for typef, (dtype, low, high, atol, rtol), size, stride, pad, ceil_mode, count_include_pad, \
Expand Down Expand Up @@ -178,12 +178,13 @@ def test_pooling():
"ceil_mode": ceil_mode,
"count_include_pad": count_include_pad
}
verify_saturation = True if dtype == "uint8" else False

for acl in [False, True]:
outputs.append(build_and_run(func, inputs, 1, None, device,
enable_acl=acl, config=config)[0])

verify(outputs, atol=atol, rtol=rtol, config=config)
verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)


def test_global_pooling():
Expand Down Expand Up @@ -212,11 +213,13 @@ def test_global_pooling():
"pooling type": typef,
"dtype": dtype,
}
verify_saturation = True if dtype == "uint8" else False

for acl in [False, True]:
outputs.append(build_and_run(func, inputs, 1, None, device,
enable_acl=acl, config=config)[0])
verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=True)

verify(outputs, atol=atol, rtol=rtol, config=config, verify_saturation=verify_saturation)


def test_codegen_pooling():
Expand Down

0 comments on commit ec9befa

Please sign in to comment.