diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 90994f3..fa3405a 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -9,6 +9,11 @@ _boolean_dtypes, _floating_dtypes, _integer_dtypes, + int8, + int16, + int32, + int64, + uint64, ) from .._flags import set_array_api_strict_flags @@ -115,6 +120,17 @@ def _array_vals(): func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): + # Disallow dtypes that aren't type promotable + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: func(x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: assert_raises(TypeError, lambda: func(x, y)) else: