1111from . import hypothesis_helpers as hh
1212from . import pytest_helpers as ph
1313from . import xps
14- from .typing import Scalar , ScalarType , Shape
14+ from .typing import DataType , Scalar , ScalarType , Shape
1515
1616
1717def axes (ndim : int ) -> st .SearchStrategy [Optional [Union [int , Shape ]]]:
@@ -22,6 +22,11 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
2222 return st .one_of (axes_strats )
2323
2424
25+ def kwarg_dtypes (dtype : DataType ) -> st .SearchStrategy [Optional [DataType ]]:
26+ dtypes = [d2 for d1 , d2 in dh .promotion_table if d1 == dtype ]
27+ return st .none () | st .sampled_from (dtypes )
28+
29+
2530def normalise_axis (
2631 axis : Optional [Union [int , Tuple [int , ...]]], ndim : int
2732) -> Tuple [int , ...]:
@@ -190,7 +195,7 @@ def test_prod(x, data):
190195 kw = data .draw (
191196 hh .kwargs (
192197 axis = axes (x .ndim ),
193- dtype = st . none () | st . just ( x .dtype ), # TODO: all valid dtypes
198+ dtype = kwarg_dtypes ( x .dtype ),
194199 keepdims = st .booleans (),
195200 ),
196201 label = "kw" ,
@@ -316,7 +321,7 @@ def test_sum(x, data):
316321 kw = data .draw (
317322 hh .kwargs (
318323 axis = axes (x .ndim ),
319- dtype = st . none () | st . just ( x .dtype ), # TODO: all valid dtypes
324+ dtype = kwarg_dtypes ( x .dtype ),
320325 keepdims = st .booleans (),
321326 ),
322327 label = "kw" ,
0 commit comments