Skip to content

Commit

Permalink
Revert "Allow any combination of real dtypes in comparisons"
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer authored Jul 18, 2024
1 parent 77a9c2d commit 899ad12
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 89 deletions.
24 changes: 8 additions & 16 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
# spec in places where it either deviates from or is more strict than
# NumPy behavior

def _check_allowed_dtypes(
self,
other: bool | int | float | Array,
dtype_category: str,
op: str,
*,
check_promotion: bool = True,
) -> Array:
def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
"""
Helper function for operators to only allow specific input dtypes
Expand All @@ -183,8 +176,7 @@ def _check_allowed_dtypes(
# This will raise TypeError for type combinations that are not allowed
# to promote in the spec (even if the NumPy array operator would
# promote them).
if check_promotion:
res_dtype = _result_type(self.dtype, other.dtype)
res_dtype = _result_type(self.dtype, other.dtype)
if op.startswith("__i"):
# Note: NumPy will allow in-place operators in some cases where
# the type promoted operator does not match the left-hand side
Expand Down Expand Up @@ -578,7 +570,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False)
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -612,7 +604,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -646,7 +638,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -700,7 +692,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -722,7 +714,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False)
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -767,7 +759,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __ne__.
"""
other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False)
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down
12 changes: 12 additions & 0 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.equal(x1._array, x2._array))

Expand Down Expand Up @@ -437,6 +439,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in greater")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater(x1._array, x2._array))

Expand All @@ -449,6 +453,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.greater_equal(x1._array, x2._array))

Expand Down Expand Up @@ -518,6 +524,8 @@ def less(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in less")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less(x1._array, x2._array))

Expand All @@ -530,6 +538,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
"""
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in less_equal")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.less_equal(x1._array, x2._array))

Expand Down Expand Up @@ -705,6 +715,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.not_equal(x1._array, x2._array))

Expand Down
91 changes: 46 additions & 45 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import operator
from builtins import all as all_

import numpy.testing
from numpy.testing import assert_raises, suppress_warnings
import numpy as np
import pytest

Expand Down Expand Up @@ -29,10 +29,6 @@

import array_api_strict

def assert_raises(exception, func, msg=None):
with numpy.testing.assert_raises(exception, msg=msg):
func()

def test_validate_index():
# The indexing tests in the official array API test suite test that the
# array object correctly handles the subset of indices that are required
Expand Down Expand Up @@ -94,7 +90,7 @@ def test_validate_index():

def test_operators():
# For every operator, we test that it works for the required type
# combinations and assert_raises TypeError otherwise
# combinations and raises TypeError otherwise
binary_op_dtypes = {
"__add__": "numeric",
"__and__": "integer_or_boolean",
Expand All @@ -115,7 +111,6 @@ def test_operators():
"__truediv__": "floating",
"__xor__": "integer_or_boolean",
}
comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]
# Recompute each time because of in-place ops
def _array_vals():
for d in _integer_dtypes:
Expand All @@ -129,7 +124,7 @@ def _array_vals():
BIG_INT = int(1e30)
for op, dtypes in binary_op_dtypes.items():
ops = [op]
if op not in comparison_ops:
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
rop = "__r" + op[2:]
iop = "__i" + op[2:]
ops += [rop, iop]
Expand Down Expand Up @@ -160,16 +155,16 @@ def _array_vals():
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
)):
if a.dtype in _integer_dtypes and s == BIG_INT:
assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op)
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
else:
# Only test for no error
with numpy.testing.suppress_warnings() as sup:
with suppress_warnings() as sup:
# ignore warnings from pow(BIG_INT)
sup.filter(RuntimeWarning,
"invalid value encountered in power")
getattr(a, _op)(s)
else:
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
assert_raises(TypeError, lambda: getattr(a, _op)(s))

# Test array op array.
for _op in ops:
Expand All @@ -178,25 +173,25 @@ def _array_vals():
# See the promotion table in NEP 47 or the array
# API spec page on type promotion. Mixed kind
# promotion is not defined.
if (op not in comparison_ops and
(x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
)):
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
):
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure in-place operators only promote to the same dtype as the left operand.
elif (
_op.startswith("__i")
and result_type(x.dtype, y.dtype) != x.dtype
):
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# Ensure only those dtypes that are required for every operator are allowed.
elif (dtypes == "all"
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
Expand All @@ -207,7 +202,7 @@ def _array_vals():
):
getattr(x, _op)(y)
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y))
assert_raises(TypeError, lambda: getattr(x, _op)(y))

unary_op_dtypes = {
"__abs__": "numeric",
Expand All @@ -226,7 +221,7 @@ def _array_vals():
# Only test for no error
getattr(a, op)()
else:
assert_raises(TypeError, lambda: getattr(a, op)(), _op)
assert_raises(TypeError, lambda: getattr(a, op)())

# Finally, matmul() must be tested separately, because it works a bit
# different from the other operations.
Expand All @@ -245,9 +240,9 @@ def _matmul_array_vals():
or type(s) == int and a.dtype in _integer_dtypes):
# Type promotion is valid, but @ is not allowed on 0-D
# inputs, so the error is a ValueError
assert_raises(ValueError, lambda: getattr(a, _op)(s), _op)
assert_raises(ValueError, lambda: getattr(a, _op)(s))
else:
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
assert_raises(TypeError, lambda: getattr(a, _op)(s))

for x in _matmul_array_vals():
for y in _matmul_array_vals():
Expand Down Expand Up @@ -361,17 +356,20 @@ def test_allow_newaxis():

def test_disallow_flat_indexing_with_newaxis():
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[None, 0, 0])
with pytest.raises(IndexError):
a[None, 0, 0]

def test_disallow_mask_with_newaxis():
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[None, asarray(True)])
with pytest.raises(IndexError):
a[None, asarray(True)]

@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
@pytest.mark.parametrize("index", ["string", False, True])
def test_error_on_invalid_index(shape, index):
a = ones(shape)
assert_raises(IndexError, lambda: a[index])
with pytest.raises(IndexError):
a[index]

def test_mask_0d_array_without_errors():
a = ones(())
Expand All @@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors():
)
def test_error_on_invalid_index_with_ellipsis(i):
a = ones((3, 3, 3))
assert_raises(IndexError, lambda: a[..., i])
assert_raises(IndexError, lambda: a[i, ...])
with pytest.raises(IndexError):
a[..., i]
with pytest.raises(IndexError):
a[i, ...]

def test_array_keys_use_private_array():
"""
Expand All @@ -400,7 +400,8 @@ def test_array_keys_use_private_array():

a = ones((0,), dtype=bool_)
key = ones((0, 0), dtype=bool_)
assert_raises(IndexError, lambda: a[key])
with pytest.raises(IndexError):
a[key]

def test_array_namespace():
a = ones((3, 3))
Expand All @@ -421,16 +422,16 @@ def test_array_namespace():
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
assert array_api_strict.__array_api_version__ == "2021.12"

assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))

def test_iter():
assert_raises(TypeError, lambda: iter(asarray(3)))
pytest.raises(TypeError, lambda: iter(asarray(3)))
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
assert all_(isinstance(a, Array) for a in iter(ones(3)))
assert all_(a.shape == () for a in iter(ones(3)))
assert all_(a.dtype == float64 for a in iter(ones(3)))
assert_raises(TypeError, lambda: iter(ones((3, 3))))
pytest.raises(TypeError, lambda: iter(ones((3, 3))))

@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
def dlpack_2023_12(api_version):
Expand All @@ -446,17 +447,17 @@ def dlpack_2023_12(api_version):


exception = NotImplementedError if api_version >= '2023.12' else ValueError
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=CPU_DEVICE))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=None))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(max_version=(1, 0)))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(max_version=None))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=False))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=True))
assert_raises(exception, lambda:
pytest.raises(exception, lambda:
a.__dlpack__(copy=None))
Loading

0 comments on commit 899ad12

Please sign in to comment.