1010from . import pytest_helpers as ph
1111from . import xps
1212
13+ RTOL = 0.05
14+
1315
1416@given (
1517 x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
@@ -37,7 +39,7 @@ def test_min(x, data):
3739 if keepdims :
3840 idx = tuple (1 for _ in x .shape )
3941 msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
40- assert out .shape == idx
42+ assert out .shape == idx , msg
4143 else :
4244 ph .assert_shape ("min" , out .shape , (), ** kw )
4345
@@ -84,7 +86,7 @@ def test_max(x, data):
8486 if keepdims :
8587 idx = tuple (1 for _ in x .shape )
8688 msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
87- assert out .shape == idx
89+ assert out .shape == idx , msg
8890 else :
8991 ph .assert_shape ("max" , out .shape , (), ** kw )
9092
@@ -105,11 +107,47 @@ def test_max(x, data):
105107 assert max_ == expected , msg
106108
107109
108- # TODO: generate kwargs
109- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
110- def test_mean (x ):
111- xp .mean (x )
112- # TODO
110+ @given (
111+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
112+ data = st .data (),
113+ )
114+ def test_mean (x , data ):
115+ axis_strats = [st .none ()]
116+ if x .shape != ():
117+ axis_strats .append (
118+ st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
119+ )
120+ kw = data .draw (
121+ hh .kwargs (axis = st .one_of (axis_strats ), keepdims = st .booleans ()), label = "kw"
122+ )
123+
124+ out = xp .mean (x , ** kw )
125+
126+ ph .assert_dtype ("mean" , x .dtype , out .dtype )
127+
128+ f_func = f"mean({ ph .fmt_kw (kw )} )"
129+
130+ # TODO: support axis
131+ if kw .get ("axis" ) is None :
132+ keepdims = kw .get ("keepdims" , False )
133+ if keepdims :
134+ idx = tuple (1 for _ in x .shape )
135+ msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
136+ assert out .shape == idx , msg
137+ else :
138+ ph .assert_shape ("max" , out .shape , (), ** kw )
139+
140+ # TODO: figure out NaN behaviour
141+ if not xp .any (xp .isnan (x )):
142+ _out = xp .reshape (out , ()) if keepdims else out
143+ elements = []
144+ for idx in ah .ndindex (x .shape ):
145+ s = float (x [idx ])
146+ elements .append (s )
147+ mean = float (_out )
148+ expected = sum (elements ) / len (elements )
149+ msg = f"out={ mean } , should be roughly { expected } [{ f_func } ]"
150+ assert math .isclose (mean , expected , rel_tol = RTOL ), msg
113151
114152
115153# TODO: generate kwargs
0 commit comments