11import math
22
3- from hypothesis import given
3+ from hypothesis import assume , given
44from hypothesis import strategies as st
55
66from . import _array_module as xp
99from . import hypothesis_helpers as hh
1010from . import pytest_helpers as ph
1111from . import xps
12+ from .typing import Scalar , ScalarType
1213
13- RTOL = 0.05
14+
15+ def assert_equals (
16+ func_name : str , type_ : ScalarType , out : Scalar , expected : Scalar , / , ** kw
17+ ):
18+ f_func = f"{ func_name } ({ ph .fmt_kw (kw )} )"
19+ if type_ is bool or type_ is int :
20+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
21+ assert out == expected , msg
22+ elif math .isnan (expected ):
23+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
24+ assert math .isnan (out ), msg
25+ else :
26+ msg = f"{ out = } , should be roughly { expected } [{ f_func } ]"
27+ assert math .isclose (out , expected , rel_tol = 0.05 ), msg
1428
1529
1630@given (
@@ -34,7 +48,7 @@ def test_min(x, data):
3448 f_func = f"min({ ph .fmt_kw (kw )} )"
3549
3650 # TODO: support axis
37- if kw .get ("axis" ) is None :
51+ if kw .get ("axis" , None ) is None :
3852 keepdims = kw .get ("keepdims" , False )
3953 if keepdims :
4054 idx = tuple (1 for _ in x .shape )
@@ -53,11 +67,7 @@ def test_min(x, data):
5367 elements .append (s )
5468 min_ = scalar_type (_out )
5569 expected = min (elements )
56- msg = f"out={ min_ } , should be { expected } [{ f_func } ]"
57- if math .isnan (min_ ):
58- assert math .isnan (expected ), msg
59- else :
60- assert min_ == expected , msg
70+ assert_equals ("min" , dh .get_scalar_type (out .dtype ), min_ , expected )
6171
6272
6373@given (
@@ -81,7 +91,7 @@ def test_max(x, data):
8191 f_func = f"max({ ph .fmt_kw (kw )} )"
8292
8393 # TODO: support axis
84- if kw .get ("axis" ) is None :
94+ if kw .get ("axis" , None ) is None :
8595 keepdims = kw .get ("keepdims" , False )
8696 if keepdims :
8797 idx = tuple (1 for _ in x .shape )
@@ -100,11 +110,7 @@ def test_max(x, data):
100110 elements .append (s )
101111 max_ = scalar_type (_out )
102112 expected = max (elements )
103- msg = f"out={ max_ } , should be { expected } [{ f_func } ]"
104- if math .isnan (max_ ):
105- assert math .isnan (expected ), msg
106- else :
107- assert max_ == expected , msg
113+ assert_equals ("mean" , dh .get_scalar_type (out .dtype ), max_ , expected )
108114
109115
110116@given (
@@ -128,7 +134,7 @@ def test_mean(x, data):
128134 f_func = f"mean({ ph .fmt_kw (kw )} )"
129135
130136 # TODO: support axis
131- if kw .get ("axis" ) is None :
137+ if kw .get ("axis" , None ) is None :
132138 keepdims = kw .get ("keepdims" , False )
133139 if keepdims :
134140 idx = tuple (1 for _ in x .shape )
@@ -146,15 +152,75 @@ def test_mean(x, data):
146152 elements .append (s )
147153 mean = float (_out )
148154 expected = sum (elements ) / len (elements )
149- msg = f"out={ mean } , should be roughly { expected } [{ f_func } ]"
150- assert math .isclose (mean , expected , rel_tol = RTOL ), msg
155+ assert_equals ("mean" , float , mean , expected )
151156
152157
153158# TODO: generate kwargs
154- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )))
155- def test_prod (x ):
156- xp .prod (x )
157- # TODO
159+ @given (
160+ x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
161+ data = st .data (),
162+ )
163+ def test_prod (x , data ):
164+ axis_strats = [st .none ()]
165+ if x .shape != ():
166+ axis_strats .append (
167+ st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
168+ )
169+ kw = data .draw (
170+ hh .kwargs (
171+ axis = st .one_of (axis_strats ),
172+ dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
173+ keepdims = st .booleans (),
174+ ),
175+ label = "kw" ,
176+ )
177+
178+ out = xp .prod (x , ** kw )
179+
180+ dtype = kw .get ("dtype" , None )
181+ if dtype is None :
182+ if dh .is_int_dtype (x .dtype ):
183+ m , M = dh .dtype_ranges [x .dtype ]
184+ d_m , d_M = dh .dtype_ranges [dh .default_int ]
185+ if m < d_m or M > d_M :
186+ _dtype = x .dtype
187+ else :
188+ _dtype = dh .default_int
189+ else :
190+ if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
191+ _dtype = x .dtype
192+ else :
193+ _dtype = dh .default_float
194+ else :
195+ _dtype = dtype
196+ ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
197+
198+ f_func = f"prod({ ph .fmt_kw (kw )} )"
199+
200+ # TODO: support axis
201+ if kw .get ("axis" , None ) is None :
202+ keepdims = kw .get ("keepdims" , False )
203+ if keepdims :
204+ idx = tuple (1 for _ in x .shape )
205+ msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
206+ assert out .shape == idx , msg
207+ else :
208+ ph .assert_shape ("prod" , out .shape , (), ** kw )
209+
210+ # TODO: figure out NaN behaviour
211+ if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
212+ _out = xp .reshape (out , ()) if keepdims else out
213+ scalar_type = dh .get_scalar_type (out .dtype )
214+ elements = []
215+ for idx in ah .ndindex (x .shape ):
216+ s = scalar_type (x [idx ])
217+ elements .append (s )
218+ prod = scalar_type (_out )
219+ expected = math .prod (elements )
220+ if dh .is_int_dtype (out .dtype ):
221+ m , M = dh .dtype_ranges [out .dtype ]
222+ assume (m <= expected <= M )
223+ assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
158224
159225
160226# TODO: generate kwargs
0 commit comments