diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index f6b7ae25..4695c08c 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -28,6 +28,7 @@ "assert_0d_equals", "assert_fill", "assert_array_elements", + "assert_allclose" ] @@ -599,3 +600,18 @@ def assert_array_elements( at_expected = expected[idx] msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) assert at_out == at_expected, msg + + +def assert_allclose(actual, desired, *, atol=1e-7, rtol=1e-7, equal_nan=True, msg_extra=""): + if equal_nan: + # XXX assert same position, mask away + pass + + msg = f"The input arrays do not have the same shapes ({actual.shape} != {desired.shape}){msg_extra}" + assert actual.shape == desired.shape, msg + + msg = f"The input arrays do not have the same shapes ({actual.shape} != {desired.shape}){msg_extra}" + assert actual.dtype == desired.dtype, msg + + delta = xp.abs(actual - desired) + assert xp.all(delta < atol + xp.abs(actual)*rtol), msg_extra diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index c997948b..5c366834 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -106,7 +106,11 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, msg_extra = f'{x_idxes = }, {res_idx = }' assert_equal(res_stack, decomp_res_stack, msg_extra) if true_val: - assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra) + expected = true_val(*x_stacks, **kw) + if decomp_res_stack.dtype in dh.all_float_dtypes: + ph.assert_allclose(decomp_res_stack, expected, msg_extra=msg_extra) + else: + assert_equal(decomp_res_stack, expected, msg_extra) def _test_namedtuple(res, fields, func_name):