@@ -56,10 +56,10 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, .
5656
5757def assert_keepdimable_shape (
5858 func_name : str ,
59+ out_shape : Shape ,
5960 in_shape : Shape ,
6061 axes : Tuple [int , ...],
6162 keepdims : bool ,
62- out_shape : Shape ,
6363 / ,
6464 ** kw ,
6565):
@@ -108,7 +108,7 @@ def test_min(x, data):
108108 ph .assert_dtype ("min" , x .dtype , out .dtype )
109109 _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
110110 assert_keepdimable_shape (
111- "min" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
111+ "min" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
112112 )
113113 scalar_type = dh .get_scalar_type (out .dtype )
114114 for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -137,7 +137,7 @@ def test_max(x, data):
137137 ph .assert_dtype ("max" , x .dtype , out .dtype )
138138 _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
139139 assert_keepdimable_shape (
140- "max" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
140+ "max" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
141141 )
142142 scalar_type = dh .get_scalar_type (out .dtype )
143143 for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -166,7 +166,7 @@ def test_mean(x, data):
166166 ph .assert_dtype ("mean" , x .dtype , out .dtype )
167167 _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
168168 assert_keepdimable_shape (
169- "mean" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
169+ "mean" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
170170 )
171171 for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
172172 mean = float (out [out_idx ])
@@ -217,7 +217,7 @@ def test_prod(x, data):
217217 ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
218218 _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
219219 assert_keepdimable_shape (
220- "prod" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
220+ "prod" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
221221 )
222222 scalar_type = dh .get_scalar_type (out .dtype )
223223 for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -264,20 +264,97 @@ def test_std(x, data):
264264
265265 ph .assert_dtype ("std" , x .dtype , out .dtype )
266266 assert_keepdimable_shape (
267- "std" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
267+ "std" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
268268 )
269269 # We can't easily test the result(s) as standard deviation methods vary a lot
270270
271271
272- # TODO: generate kwargs
273- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )))
274- def test_sum (x ):
275- xp .sum (x )
276- # TODO
272+ @given (
273+ x = xps .arrays (
274+ dtype = xps .floating_dtypes (),
275+ shape = hh .shapes (min_side = 1 ),
276+ elements = {"allow_nan" : False },
277+ ).filter (lambda x : x .size >= 2 ),
278+ data = st .data (),
279+ )
280+ def test_var (x , data ):
281+ axis = data .draw (axes (x .ndim ), label = "axis" )
282+ _axes = normalise_axis (axis , x .ndim )
283+ N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
284+ correction = data .draw (
285+ st .floats (0.0 , N , allow_infinity = False , allow_nan = False ) | st .integers (0 , N ),
286+ label = "correction" ,
287+ )
288+ keepdims = data .draw (st .booleans (), label = "keepdims" )
289+ kw = data .draw (
290+ hh .specified_kwargs (
291+ ("axis" , axis , None ),
292+ ("correction" , correction , 0.0 ),
293+ ("keepdims" , keepdims , False ),
294+ ),
295+ label = "kw" ,
296+ )
297+
298+ out = xp .var (x , ** kw )
299+
300+ ph .assert_dtype ("var" , x .dtype , out .dtype )
301+ assert_keepdimable_shape (
302+ "var" , out .shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
303+ )
304+ # We can't easily test the result(s) as variance methods vary a lot
305+
306+
307+ @given (
308+ x = xps .arrays (
309+ dtype = xps .numeric_dtypes (),
310+ shape = hh .shapes (min_side = 1 ),
311+ elements = {"allow_nan" : False },
312+ ),
313+ data = st .data (),
314+ )
315+ def test_sum (x , data ):
316+ kw = data .draw (
317+ hh .kwargs (
318+ axis = axes (x .ndim ),
319+ dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
320+ keepdims = st .booleans (),
321+ ),
322+ label = "kw" ,
323+ )
277324
325+ out = xp .sum (x , ** kw )
278326
279- # TODO: generate kwargs
280- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
281- def test_var (x ):
282- xp .var (x )
283- # TODO
327+ dtype = kw .get ("dtype" , None )
328+ if dtype is None :
329+ if dh .is_int_dtype (x .dtype ):
330+ m , M = dh .dtype_ranges [x .dtype ]
331+ d_m , d_M = dh .dtype_ranges [dh .default_int ]
332+ if m < d_m or M > d_M :
333+ _dtype = x .dtype
334+ else :
335+ _dtype = dh .default_int
336+ else :
337+ if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
338+ _dtype = x .dtype
339+ else :
340+ _dtype = dh .default_float
341+ else :
342+ _dtype = dtype
343+ ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
344+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
345+ assert_keepdimable_shape (
346+ "sum" , out .shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
347+ )
348+ scalar_type = dh .get_scalar_type (out .dtype )
349+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
350+ sum_ = scalar_type (out [out_idx ])
351+ assume (not math .isinf (sum_ ))
352+ elements = []
353+ for idx in indices :
354+ s = scalar_type (x [idx ])
355+ elements .append (s )
356+ expected = sum (elements )
357+ if dh .is_int_dtype (out .dtype ):
358+ m , M = dh .dtype_ranges [out .dtype ]
359+ assume (m <= expected <= M )
360+ assert_equals ("sum" , dh .get_scalar_type (out .dtype ), out_idx , sum_ , expected )
0 commit comments