Skip to content

Commit

Permalink
ENH: add crude assert_allclose; use value testing in vecdot
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Nov 26, 2024
1 parent a71b4c0 commit 5983cac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
16 changes: 16 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"assert_0d_equals",
"assert_fill",
"assert_array_elements",
"assert_allclose"
]


Expand Down Expand Up @@ -599,3 +600,18 @@ def assert_array_elements(
at_expected = expected[idx]
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
assert at_out == at_expected, msg


def assert_allclose(actual, desired, *, atol=1e-7, rtol=1e-7, equal_nan=True, msg_extra=""):
if equal_nan:
# XXX assert same position, mask away
pass

msg = f"The input arrays do not have the same shapes ({actual.shape} != {desired.shape}){msg_extra}"
assert actual.shape == desired.shape, msg

msg = f"The input arrays do not have the same shapes ({actual.shape} != {desired.shape}){msg_extra}"
assert actual.dtype == desired.dtype, msg

delta = xp.abs(actual - desired)
assert xp.all(delta < atol + xp.abs(actual)*rtol), msg_extra
6 changes: 5 additions & 1 deletion array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
msg_extra = f'{x_idxes = }, {res_idx = }'
assert_equal(res_stack, decomp_res_stack, msg_extra)
if true_val:
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
expected = true_val(*x_stacks, **kw)
if decomp_res_stack.dtype in dh.all_float_dtypes:
ph.assert_allclose(decomp_res_stack, expected, msg_extra=msg_extra)
else:
assert_equal(decomp_res_stack, expected, msg_extra)


def _test_namedtuple(res, fields, func_name):
Expand Down

0 comments on commit 5983cac

Please sign in to comment.