Skip to content

Commit

Permalink
Restore extra testing of elementwise function disallowed type promotions
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Jul 18, 2024
1 parent 899ad12 commit 1c03aaa
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
_boolean_dtypes,
_floating_dtypes,
_integer_dtypes,
int8,
int16,
int32,
int64,
uint64,
)
from .._flags import set_array_api_strict_flags

Expand Down Expand Up @@ -115,6 +120,17 @@ def _array_vals():
func = getattr(_elementwise_functions, func_name)
if nargs(func) == 2:
for y in _array_vals():
# Disallow dtypes that aren't type promotable
if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
):
assert_raises(TypeError, lambda: func(x, y))
if x.dtype not in dtypes or y.dtype not in dtypes:
assert_raises(TypeError, lambda: func(x, y))
else:
Expand Down

0 comments on commit 1c03aaa

Please sign in to comment.