@@ -19,18 +19,10 @@ def non_complex_dtypes():
19
19
return xps .boolean_dtypes () | hh .real_dtypes
20
20
21
21
22
- def numeric_dtypes ():
23
- return xps .boolean_dtypes () | hh .real_dtypes | hh .complex_dtypes
24
-
25
-
26
22
def float32 (n : Union [int , float ]) -> float :
27
23
return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
28
24
29
25
30
- def _float_match_complex (complex_dtype ):
31
- return xp .float32 if complex_dtype == xp .complex64 else xp .float64
32
-
33
-
34
26
@given (
35
27
x_dtype = non_complex_dtypes (),
36
28
dtype = non_complex_dtypes (),
@@ -115,46 +107,23 @@ def test_broadcast_to(x, data):
115
107
# TODO: test values
116
108
117
109
118
- @given (_from = numeric_dtypes (), to = numeric_dtypes (), data = st .data ())
119
- def test_can_cast (_from , to , data ):
120
- from_ = data .draw (
121
- st .just (_from ) | hh .arrays (dtype = _from , shape = hh .shapes ()), label = "from_"
122
- )
110
+ @given (_from = hh .all_dtypes , to = hh .all_dtypes )
111
+ def test_can_cast (_from , to ):
112
+ out = xp .can_cast (_from , to )
123
113
124
- out = xp .can_cast (from_ , to )
114
+ expected = False
115
+ for other in dh .all_dtypes :
116
+ if dh .promotion_table .get ((_from , other )) == to :
117
+ expected = True
118
+ break
125
119
126
120
f_func = f"[can_cast({ dh .dtype_to_name [_from ]} , { dh .dtype_to_name [to ]} )]"
127
- assert isinstance (out , bool ), f"{ type (out )= } , but should be bool { f_func } "
128
- if _from == xp .bool :
129
- expected = to == xp .bool
130
- else :
131
- same_family = None
132
- for dtypes in [dh .all_int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]:
133
- if _from in dtypes :
134
- same_family = to in dtypes
135
- break
136
- assert same_family is not None # sanity check
137
- if same_family :
138
- from_dtype = (_float_match_complex (_from )
139
- if _from in (xp .complex64 , xp .complex128 )
140
- else _from )
141
- to_dtype = (_float_match_complex (to )
142
- if to in (xp .complex64 , xp .complex128 )
143
- else to )
144
-
145
- from_min , from_max = dh .dtype_ranges [from_dtype ]
146
- to_min , to_max = dh .dtype_ranges [to_dtype ]
147
- expected = from_min >= to_min and from_max <= to_max
148
- else :
149
- expected = False
150
121
if expected :
151
122
# cross-kind casting is not explicitly disallowed. We can only test
152
- # the cases where it should return True. TODO: if expected=False,
153
- # check that the array library actually allows such casts.
123
+ # the cases where it should return True.
154
124
assert out == expected , f"{ out = } , but should be { expected } { f_func } "
155
125
156
126
157
-
158
127
@pytest .mark .parametrize ("dtype" , dh .real_float_dtypes )
159
128
def test_finfo (dtype ):
160
129
out = xp .finfo (dtype )
0 commit comments