11import math
2- from typing import Optional , Union
2+ from itertools import product
3+ from typing import Iterator , Optional , Tuple , Union
34
45from hypothesis import assume , given
56from hypothesis import strategies as st
@@ -21,23 +22,82 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
2122 return st .one_of (axes_strats )
2223
2324
25+ def normalise_axis (
26+ axis : Optional [Union [int , Tuple [int , ...]]], ndim : int
27+ ) -> Tuple [int , ...]:
28+ if axis is None :
29+ return tuple (range (ndim ))
30+ axes = axis if isinstance (axis , tuple ) else (axis ,)
31+ axes = tuple (axis if axis >= 0 else ndim + axis for axis in axes )
32+ return axes
33+
34+
35+ def axes_ndindex (shape : Shape , axes : Tuple [int , ...]) -> Iterator [Tuple [Shape , ...]]:
36+ base_iterables = []
37+ axes_iterables = []
38+ for axis , side in enumerate (shape ):
39+ if axis in axes :
40+ base_iterables .append ((None ,))
41+ axes_iterables .append (range (side ))
42+ else :
43+ base_iterables .append (range (side ))
44+ axes_iterables .append ((None ,))
45+ for base_idx in product (* base_iterables ):
46+ indices = []
47+ for idx in product (* axes_iterables ):
48+ idx = list (idx )
49+ for axis , side in enumerate (idx ):
50+ if axis not in axes :
51+ idx [axis ] = base_idx [axis ]
52+ idx = tuple (idx )
53+ indices .append (idx )
54+ yield tuple (indices )
55+
56+
57+ def assert_keepdimable_shape (
58+ func_name : str ,
59+ in_shape : Shape ,
60+ axes : Tuple [int , ...],
61+ keepdims : bool ,
62+ out_shape : Shape ,
63+ / ,
64+ ** kw ,
65+ ):
66+ if keepdims :
67+ shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
68+ else :
69+ shape = tuple (side for axis , side in enumerate (in_shape ) if axis not in axes )
70+ ph .assert_shape (func_name , out_shape , shape , ** kw )
71+
72+
2473def assert_equals (
25- func_name : str , type_ : ScalarType , out : Scalar , expected : Scalar , / , ** kw
74+ func_name : str ,
75+ type_ : ScalarType ,
76+ idx : Shape ,
77+ out : Scalar ,
78+ expected : Scalar ,
79+ / ,
80+ ** kw ,
2681):
82+ out_repr = "out" if idx == () else f"out[{ idx } ]"
2783 f_func = f"{ func_name } ({ ph .fmt_kw (kw )} )"
2884 if type_ is bool or type_ is int :
29- msg = f"{ out = } , should be { expected } [{ f_func } ]"
85+ msg = f"{ out_repr } = { out } , should be { expected } [{ f_func } ]"
3086 assert out == expected , msg
3187 elif math .isnan (expected ):
32- msg = f"{ out = } , should be { expected } [{ f_func } ]"
88+ msg = f"{ out_repr } = { out } , should be { expected } [{ f_func } ]"
3389 assert math .isnan (out ), msg
3490 else :
35- msg = f"{ out = } , should be roughly { expected } [{ f_func } ]"
91+ msg = f"{ out_repr } = { out } , should be roughly { expected } [{ f_func } ]"
3692 assert math .isclose (out , expected , rel_tol = 0.05 ), msg
3793
3894
3995@given (
40- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
96+ x = xps .arrays (
97+ dtype = xps .numeric_dtypes (),
98+ shape = hh .shapes (min_side = 1 ),
99+ elements = {"allow_nan" : False },
100+ ),
41101 data = st .data (),
42102)
43103def test_min (x , data ):
@@ -46,34 +106,27 @@ def test_min(x, data):
46106 out = xp .min (x , ** kw )
47107
48108 ph .assert_dtype ("min" , x .dtype , out .dtype )
49-
50- f_func = f"min({ ph .fmt_kw (kw )} )"
51-
52- # TODO: support axis
53- if kw .get ("axis" , None ) is None :
54- keepdims = kw .get ("keepdims" , False )
55- if keepdims :
56- shape = tuple (1 for _ in x .shape )
57- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
58- assert out .shape == shape , msg
59- else :
60- ph .assert_shape ("min" , out .shape , (), ** kw )
61-
62- # TODO: figure out NaN behaviour
63- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
64- _out = xp .reshape (out , ()) if keepdims else out
65- scalar_type = dh .get_scalar_type (out .dtype )
66- elements = []
67- for idx in ah .ndindex (x .shape ):
68- s = scalar_type (x [idx ])
69- elements .append (s )
70- min_ = scalar_type (_out )
71- expected = min (elements )
72- assert_equals ("min" , dh .get_scalar_type (out .dtype ), min_ , expected )
109+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
110+ assert_keepdimable_shape (
111+ "min" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
112+ )
113+ scalar_type = dh .get_scalar_type (out .dtype )
114+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
115+ min_ = scalar_type (out [out_idx ])
116+ elements = []
117+ for idx in indices :
118+ s = scalar_type (x [idx ])
119+ elements .append (s )
120+ expected = min (elements )
121+ assert_equals ("min" , dh .get_scalar_type (out .dtype ), out_idx , min_ , expected )
73122
74123
75124@given (
76- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
125+ x = xps .arrays (
126+ dtype = xps .numeric_dtypes (),
127+ shape = hh .shapes (min_side = 1 ),
128+ elements = {"allow_nan" : False },
129+ ),
77130 data = st .data (),
78131)
79132def test_max (x , data ):
@@ -82,34 +135,27 @@ def test_max(x, data):
82135 out = xp .max (x , ** kw )
83136
84137 ph .assert_dtype ("max" , x .dtype , out .dtype )
85-
86- f_func = f"max({ ph .fmt_kw (kw )} )"
87-
88- # TODO: support axis
89- if kw .get ("axis" , None ) is None :
90- keepdims = kw .get ("keepdims" , False )
91- if keepdims :
92- shape = tuple (1 for _ in x .shape )
93- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
94- assert out .shape == shape , msg
95- else :
96- ph .assert_shape ("max" , out .shape , (), ** kw )
97-
98- # TODO: figure out NaN behaviour
99- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
100- _out = xp .reshape (out , ()) if keepdims else out
101- scalar_type = dh .get_scalar_type (out .dtype )
102- elements = []
103- for idx in ah .ndindex (x .shape ):
104- s = scalar_type (x [idx ])
105- elements .append (s )
106- max_ = scalar_type (_out )
107- expected = max (elements )
108- assert_equals ("mean" , dh .get_scalar_type (out .dtype ), max_ , expected )
138+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
139+ assert_keepdimable_shape (
140+ "max" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
141+ )
142+ scalar_type = dh .get_scalar_type (out .dtype )
143+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
144+ max_ = scalar_type (out [out_idx ])
145+ elements = []
146+ for idx in indices :
147+ s = scalar_type (x [idx ])
148+ elements .append (s )
149+ expected = max (elements )
150+ assert_equals ("max" , dh .get_scalar_type (out .dtype ), out_idx , max_ , expected )
109151
110152
111153@given (
112- x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
154+ x = xps .arrays (
155+ dtype = xps .floating_dtypes (),
156+ shape = hh .shapes (min_side = 1 ),
157+ elements = {"allow_nan" : False },
158+ ),
113159 data = st .data (),
114160)
115161def test_mean (x , data ):
@@ -118,33 +164,26 @@ def test_mean(x, data):
118164 out = xp .mean (x , ** kw )
119165
120166 ph .assert_dtype ("mean" , x .dtype , out .dtype )
121-
122- f_func = f"mean({ ph .fmt_kw (kw )} )"
123-
124- # TODO: support axis
125- if kw .get ("axis" , None ) is None :
126- keepdims = kw .get ("keepdims" , False )
127- if keepdims :
128- shape = tuple (1 for _ in x .shape )
129- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
130- assert out .shape == shape , msg
131- else :
132- ph .assert_shape ("max" , out .shape , (), ** kw )
133-
134- # TODO: figure out NaN behaviour
135- if not xp .any (xp .isnan (x )):
136- _out = xp .reshape (out , ()) if keepdims else out
137- elements = []
138- for idx in ah .ndindex (x .shape ):
139- s = float (x [idx ])
140- elements .append (s )
141- mean = float (_out )
142- expected = sum (elements ) / len (elements )
143- assert_equals ("mean" , float , mean , expected )
167+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
168+ assert_keepdimable_shape (
169+ "mean" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
170+ )
171+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
172+ mean = float (out [out_idx ])
173+ elements = []
174+ for idx in indices :
175+ s = float (x [idx ])
176+ elements .append (s )
177+ expected = sum (elements ) / len (elements )
178+ assert_equals ("mean" , dh .get_scalar_type (out .dtype ), out_idx , mean , expected )
144179
145180
146181@given (
147- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
182+ x = xps .arrays (
183+ dtype = xps .numeric_dtypes (),
184+ shape = hh .shapes (min_side = 1 ),
185+ elements = {"allow_nan" : False },
186+ ),
148187 data = st .data (),
149188)
150189def test_prod (x , data ):
@@ -176,52 +215,37 @@ def test_prod(x, data):
176215 else :
177216 _dtype = dtype
178217 ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
179-
180- f_func = f"prod({ ph .fmt_kw (kw )} )"
181-
182- # TODO: support axis
183- if kw .get ("axis" , None ) is None :
184- keepdims = kw .get ("keepdims" , False )
185- if keepdims :
186- shape = tuple (1 for _ in x .shape )
187- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
188- assert out .shape == shape , msg
189- else :
190- ph .assert_shape ("prod" , out .shape , (), ** kw )
191-
192- # TODO: figure out NaN behaviour
193- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
194- _out = xp .reshape (out , ()) if keepdims else out
195- scalar_type = dh .get_scalar_type (out .dtype )
196- elements = []
197- for idx in ah .ndindex (x .shape ):
198- s = scalar_type (x [idx ])
199- elements .append (s )
200- prod = scalar_type (_out )
201- expected = math .prod (elements )
202- if dh .is_int_dtype (out .dtype ):
203- m , M = dh .dtype_ranges [out .dtype ]
204- assume (m <= expected <= M )
205- assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
218+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
219+ assert_keepdimable_shape (
220+ "prod" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
221+ )
222+ scalar_type = dh .get_scalar_type (out .dtype )
223+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
224+ prod = scalar_type (out [out_idx ])
225+ assume (not math .isinf (prod ))
226+ elements = []
227+ for idx in indices :
228+ s = scalar_type (x [idx ])
229+ elements .append (s )
230+ expected = math .prod (elements )
231+ if dh .is_int_dtype (out .dtype ):
232+ m , M = dh .dtype_ranges [out .dtype ]
233+ assume (m <= expected <= M )
234+ assert_equals ("prod" , dh .get_scalar_type (out .dtype ), out_idx , prod , expected )
206235
207236
208237@given (
209- x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )).filter (
210- lambda x : x .size >= 2
211- ),
238+ x = xps .arrays (
239+ dtype = xps .floating_dtypes (),
240+ shape = hh .shapes (min_side = 1 ),
241+ elements = {"allow_nan" : False },
242+ ).filter (lambda x : x .size >= 2 ),
212243 data = st .data (),
213244)
214245def test_std (x , data ):
215246 axis = data .draw (axes (x .ndim ), label = "axis" )
216- if axis is None :
217- N = x .size
218- _axes = tuple (range (x .ndim ))
219- else :
220- _axes = axis if isinstance (axis , tuple ) else (axis ,)
221- _axes = tuple (
222- axis if axis >= 0 else x .ndim + axis for axis in _axes
223- ) # normalise
224- N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
247+ _axes = normalise_axis (axis , x .ndim )
248+ N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
225249 correction = data .draw (
226250 st .floats (0.0 , N , allow_infinity = False , allow_nan = False ) | st .integers (0 , N ),
227251 label = "correction" ,
@@ -239,13 +263,9 @@ def test_std(x, data):
239263 out = xp .std (x , ** kw )
240264
241265 ph .assert_dtype ("std" , x .dtype , out .dtype )
242-
243- if keepdims :
244- shape = tuple (1 if axis in _axes else side for axis , side in enumerate (x .shape ))
245- else :
246- shape = tuple (side for axis , side in enumerate (x .shape ) if axis not in _axes )
247- ph .assert_shape ("std" , out .shape , shape , ** kw )
248-
266+ assert_keepdimable_shape (
267+ "std" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
268+ )
249269 # We can't easily test the result(s) as standard deviation methods vary a lot
250270
251271
0 commit comments