Skip to content

Commit 9f28ec8

Browse files
committed
Simplify test_can_cast
1) test all dtypes 2) check against the promotion table 3) remove checking the limits (value-based casting?)
1 parent 5120204 commit 9f28ec8

File tree

1 file changed

+9
-40
lines changed

1 file changed

+9
-40
lines changed

array_api_tests/test_data_type_functions.py

+9-40
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,10 @@ def non_complex_dtypes():
1919
return xps.boolean_dtypes() | hh.real_dtypes
2020

2121

22-
def numeric_dtypes():
23-
return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes
24-
25-
2622
def float32(n: Union[int, float]) -> float:
2723
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2824

2925

30-
def _float_match_complex(complex_dtype):
31-
return xp.float32 if complex_dtype == xp.complex64 else xp.float64
32-
33-
3426
@given(
3527
x_dtype=non_complex_dtypes(),
3628
dtype=non_complex_dtypes(),
@@ -115,46 +107,23 @@ def test_broadcast_to(x, data):
115107
# TODO: test values
116108

117109

118-
@given(_from=numeric_dtypes(), to=numeric_dtypes(), data=st.data())
119-
def test_can_cast(_from, to, data):
120-
from_ = data.draw(
121-
st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_"
122-
)
110+
@given(_from=hh.all_dtypes, to=hh.all_dtypes)
111+
def test_can_cast(_from, to):
112+
out = xp.can_cast(_from, to)
123113

124-
out = xp.can_cast(from_, to)
114+
expected = False
115+
for other in dh.all_dtypes:
116+
if dh.promotion_table.get((_from, other)) == to:
117+
expected = True
118+
break
125119

126120
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
127-
assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}"
128-
if _from == xp.bool:
129-
expected = to == xp.bool
130-
else:
131-
same_family = None
132-
for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]:
133-
if _from in dtypes:
134-
same_family = to in dtypes
135-
break
136-
assert same_family is not None # sanity check
137-
if same_family:
138-
from_dtype = (_float_match_complex(_from)
139-
if _from in (xp.complex64, xp.complex128)
140-
else _from)
141-
to_dtype = (_float_match_complex(to)
142-
if to in (xp.complex64, xp.complex128)
143-
else to)
144-
145-
from_min, from_max = dh.dtype_ranges[from_dtype]
146-
to_min, to_max = dh.dtype_ranges[to_dtype]
147-
expected = from_min >= to_min and from_max <= to_max
148-
else:
149-
expected = False
150121
if expected:
151122
# cross-kind casting is not explicitly disallowed. We can only test
152-
# the cases where it should return True. TODO: if expected=False,
153-
# check that the array library actually allows such casts.
123+
# the cases where it should return True.
154124
assert out == expected, f"{out=}, but should be {expected} {f_func}"
155125

156126

157-
158127
@pytest.mark.parametrize("dtype", dh.real_float_dtypes)
159128
def test_finfo(dtype):
160129
out = xp.finfo(dtype)

0 commit comments

Comments
 (0)