diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index bded0c6..d8ed018 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -152,14 +152,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes( - self, - other: bool | int | float | Array, - dtype_category: str, - op: str, - *, - check_promotion: bool = True, - ) -> Array: + def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -183,8 +176,7 @@ def _check_allowed_dtypes( # This will raise TypeError for type combinations that are not allowed # to promote in the spec (even if the NumPy array operator would # promote them). - if check_promotion: - res_dtype = _result_type(self.dtype, other.dtype) + res_dtype = _result_type(self.dtype, other.dtype) if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side @@ -578,7 +570,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False) + other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -612,7 +604,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -646,7 +638,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -700,7 +692,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -722,7 +714,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -767,7 +759,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False) + other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index d4a108d..b39bd86 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -375,6 +375,8 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) @@ -437,6 +439,8 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) @@ -449,6 +453,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) @@ -518,6 +524,8 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) @@ -530,6 +538,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) @@ -705,6 +715,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 04e606e..b0d4868 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,7 +1,7 @@ import operator from builtins import all as all_ -import numpy.testing +from numpy.testing import assert_raises, suppress_warnings import numpy as np import pytest @@ -29,10 +29,6 @@ import array_api_strict -def assert_raises(exception, func, msg=None): - with numpy.testing.assert_raises(exception, msg=msg): - func() - def test_validate_index(): # The indexing tests in the official array API test suite test that the # array object correctly handles the subset of indices that are required @@ -94,7 +90,7 @@ def test_validate_index(): def test_operators(): # For every operator, we test that it works for the required type - # combinations and assert_raises TypeError otherwise + # combinations and raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", "__and__": "integer_or_boolean", @@ -115,7 +111,6 @@ def test_operators(): "__truediv__": "floating", "__xor__": "integer_or_boolean", } - comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"] # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -129,7 +124,7 @@ def _array_vals(): BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in comparison_ops: + if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: rop = "__r" + op[2:] iop = "__i" + op[2:] ops += [rop, iop] @@ -160,16 +155,16 @@ def _array_vals(): or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op) + assert_raises(OverflowError, lambda: getattr(a, _op)(s)) else: # Only test for no error - with numpy.testing.suppress_warnings() as sup: + with suppress_warnings() as sup: # ignore warnings from pow(BIG_INT) sup.filter(RuntimeWarning, "invalid value encountered in power") getattr(a, _op)(s) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) + assert_raises(TypeError, lambda: getattr(a, _op)(s)) # Test array op array. for _op in ops: @@ -178,25 +173,25 @@ def _array_vals(): # See the promotion table in NEP 47 or the array # API spec page on type promotion. Mixed kind # promotion is not defined. - if (op not in comparison_ops and - (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] - or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] - or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes - or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes - or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes - or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes - or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes - or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - )): - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure in-place operators only promote to the same dtype as the left operand. elif ( _op.startswith("__i") and result_type(x.dtype, y.dtype) != x.dtype ): - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure only those dtypes that are required for every operator are allowed. - elif (dtypes == "all" + elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes @@ -207,7 +202,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y)) unary_op_dtypes = { "__abs__": "numeric", @@ -226,7 +221,7 @@ def _array_vals(): # Only test for no error getattr(a, op)() else: - assert_raises(TypeError, lambda: getattr(a, op)(), _op) + assert_raises(TypeError, lambda: getattr(a, op)()) # Finally, matmul() must be tested separately, because it works a bit # different from the other operations. @@ -245,9 +240,9 @@ def _matmul_array_vals(): or type(s) == int and a.dtype in _integer_dtypes): # Type promotion is valid, but @ is not allowed on 0-D # inputs, so the error is a ValueError - assert_raises(ValueError, lambda: getattr(a, _op)(s), _op) + assert_raises(ValueError, lambda: getattr(a, _op)(s)) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) + assert_raises(TypeError, lambda: getattr(a, _op)(s)) for x in _matmul_array_vals(): for y in _matmul_array_vals(): @@ -361,17 +356,20 @@ def test_allow_newaxis(): def test_disallow_flat_indexing_with_newaxis(): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[None, 0, 0]) + with pytest.raises(IndexError): + a[None, 0, 0] def test_disallow_mask_with_newaxis(): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[None, asarray(True)]) + with pytest.raises(IndexError): + a[None, asarray(True)] @pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) @pytest.mark.parametrize("index", ["string", False, True]) def test_error_on_invalid_index(shape, index): a = ones(shape) - assert_raises(IndexError, lambda: a[index]) + with pytest.raises(IndexError): + a[index] def test_mask_0d_array_without_errors(): a = ones(()) @@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors(): ) def test_error_on_invalid_index_with_ellipsis(i): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[..., i]) - assert_raises(IndexError, lambda: a[i, ...]) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] def test_array_keys_use_private_array(): """ @@ -400,7 +400,8 @@ def test_array_keys_use_private_array(): a = ones((0,), dtype=bool_) key = ones((0, 0), dtype=bool_) - assert_raises(IndexError, lambda: a[key]) + with pytest.raises(IndexError): + a[key] def test_array_namespace(): a = ones((3, 3)) @@ -421,16 +422,16 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" - assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) def test_iter(): - assert_raises(TypeError, lambda: iter(asarray(3))) + pytest.raises(TypeError, lambda: iter(asarray(3))) assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] assert all_(isinstance(a, Array) for a in iter(ones(3))) assert all_(a.shape == () for a in iter(ones(3))) assert all_(a.dtype == float64 for a in iter(ones(3))) - assert_raises(TypeError, lambda: iter(ones((3, 3)))) + pytest.raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): @@ -446,17 +447,17 @@ def dlpack_2023_12(api_version): exception = NotImplementedError if api_version >= '2023.12' else ValueError - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(dl_device=None)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(max_version=(1, 0))) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(max_version=None)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=False)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=True)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=None)) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 92c9c59..fa3405a 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,6 @@ from inspect import getfullargspec, getmodule -from .test_array_object import assert_raises +from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -91,15 +91,6 @@ def nargs(func): "trunc": "real numeric", } -comparison_functions = [ - 'equal', - 'greater', - 'greater_equal', - 'less', - 'less_equal', - 'not_equal', -] - def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod @@ -130,8 +121,7 @@ def _array_vals(): if nargs(func) == 2: for y in _array_vals(): # Disallow dtypes that aren't type promotable - if (func_name not in comparison_functions and - (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes @@ -139,10 +129,10 @@ def _array_vals(): or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - )): - assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) + ): + assert_raises(TypeError, lambda: func(x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) + assert_raises(TypeError, lambda: func(x, y)) else: if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x))