22from operator import mul
33from math import sqrt
44import itertools
5- from typing import Tuple
5+ from typing import Tuple , Optional , List
66
77from hypothesis import assume
88from hypothesis .strategies import (lists , integers , sampled_from ,
99 shared , floats , just , composite , one_of ,
10- none , booleans )
11- from hypothesis .strategies ._internal .strategies import SearchStrategy
10+ none , booleans , SearchStrategy )
1211
1312from .pytest_helpers import nargs
1413from .array_helpers import ndindex
14+ from .typing import DataType , Shape
1515from . import dtype_helpers as dh
1616from ._array_module import (full , float32 , float64 , bool as bool_dtype ,
1717 _UndefinedStub , eye , broadcast_to )
5050_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .float_dtypes ]
5151_sorted_dtypes = [d for category in _dtype_categories for d in category ]
5252
53- def _dtypes_sorter (dtype_pair ):
53+ def _dtypes_sorter (dtype_pair : Tuple [ DataType , DataType ] ):
5454 dtype1 , dtype2 = dtype_pair
5555 if dtype1 == dtype2 :
5656 return _sorted_dtypes .index (dtype1 )
@@ -67,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
6767 key += 1
6868 return key
6969
70- promotable_dtypes = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
70+ promotable_dtypes : List [ Tuple [ DataType , DataType ]] = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
7171
7272if FILTER_UNDEFINED_DTYPES :
7373 promotable_dtypes = [
@@ -77,10 +77,34 @@ def _dtypes_sorter(dtype_pair):
7777 ]
7878
7979
80- def mutually_promotable_dtypes (dtype_objs = dh .all_dtypes ):
81- return sampled_from (
82- [(i , j ) for i , j in promotable_dtypes if i in dtype_objs and j in dtype_objs ]
83- )
80+ def mutually_promotable_dtypes (
81+ max_size : Optional [int ] = 2 ,
82+ * ,
83+ dtypes : Tuple [DataType , ...] = dh .all_dtypes ,
84+ ) -> SearchStrategy [Tuple [DataType , ...]]:
85+ if max_size == 2 :
86+ return sampled_from (
87+ [(i , j ) for i , j in promotable_dtypes if i in dtypes and j in dtypes ]
88+ )
89+ if isinstance (max_size , int ) and max_size < 2 :
90+ raise ValueError (f'{ max_size = } should be >=2' )
91+ strats = []
92+ category_samples = {
93+ category : [d for d in dtypes if d in category ] for category in _dtype_categories
94+ }
95+ for samples in category_samples .values ():
96+ if len (samples ) > 0 :
97+ strat = lists (sampled_from (samples ), min_size = 2 , max_size = max_size )
98+ strats .append (strat )
99+ if len (category_samples [dh .uint_dtypes ]) > 0 and len (category_samples [dh .int_dtypes ]) > 0 :
100+ mixed_samples = category_samples [dh .uint_dtypes ] + category_samples [dh .int_dtypes ]
101+ strat = lists (sampled_from (mixed_samples ), min_size = 2 , max_size = max_size )
102+ if xp .uint64 in mixed_samples :
103+ strat = strat .filter (
104+ lambda l : not (xp .uint64 in l and any (d in dh .int_dtypes for d in l ))
105+ )
106+ return one_of (strats ).map (tuple )
107+
84108
85109# shared() allows us to draw either the function or the function name and they
86110# will both correspond to the same function.
@@ -113,15 +137,19 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
113137
114138# Use this to avoid memory errors with NumPy.
115139# See https://github.com/numpy/numpy/issues/15753
116- shapes = xps .array_shapes (min_dims = 0 , min_side = 0 ).filter (
117- lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
118- )
140+ def shapes (** kw ):
141+ kw .setdefault ('min_dims' , 0 )
142+ kw .setdefault ('min_side' , 0 )
143+ return xps .array_shapes (** kw ).filter (
144+ lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
145+ )
146+
119147
120148one_d_shapes = xps .array_shapes (min_dims = 1 , max_dims = 1 , min_side = 0 , max_side = SQRT_MAX_ARRAY_SIZE )
121149
122150# Matrix shapes assume stacks of matrices
123151@composite
124- def matrix_shapes (draw , stack_shapes = shapes ):
152+ def matrix_shapes (draw , stack_shapes = shapes () ):
125153 stack_shape = draw (stack_shapes )
126154 mat_shape = draw (xps .array_shapes (max_dims = 2 , min_dims = 2 ))
127155 shape = stack_shape + mat_shape
@@ -135,9 +163,11 @@ def matrix_shapes(draw, stack_shapes=shapes):
135163 elements = dict (allow_nan = False ,
136164 allow_infinity = False ))
137165
138- def mutually_broadcastable_shapes (num_shapes : int ) -> SearchStrategy [Tuple [Tuple ]]:
166+ def mutually_broadcastable_shapes (
167+ num_shapes : int , ** kw
168+ ) -> SearchStrategy [Tuple [Shape , ...]]:
139169 return (
140- xps .mutually_broadcastable_shapes (num_shapes )
170+ xps .mutually_broadcastable_shapes (num_shapes , ** kw )
141171 .map (lambda BS : BS .input_shapes )
142172 .filter (lambda shapes : all (
143173 prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
@@ -164,13 +194,13 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
164194 # using something like
165195 # https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
166196 n = draw (integers (0 ))
167- shape = draw (shapes ) + (n , n )
197+ shape = draw (shapes () ) + (n , n )
168198 assume (prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
169199 dtype = draw (dtypes )
170200 return broadcast_to (eye (n , dtype = dtype ), shape )
171201
172202@composite
173- def invertible_matrices (draw , dtypes = xps .floating_dtypes (), stack_shapes = shapes ):
203+ def invertible_matrices (draw , dtypes = xps .floating_dtypes (), stack_shapes = shapes () ):
174204 # For now, just generate stacks of diagonal matrices.
175205 n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
176206 stack_shape = draw (stack_shapes )
@@ -318,9 +348,10 @@ def multiaxis_indices(draw, shapes):
318348
319349
320350def two_mutual_arrays (
321- dtype_objs = dh .all_dtypes , two_shapes = two_mutually_broadcastable_shapes
322- ):
323- mutual_dtypes = shared (mutually_promotable_dtypes (dtype_objs ))
351+ dtypes : Tuple [DataType , ...] = dh .all_dtypes ,
352+ two_shapes : SearchStrategy [Tuple [Shape , Shape ]] = two_mutually_broadcastable_shapes ,
353+ ) -> SearchStrategy :
354+ mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
324355 mutual_shapes = shared (two_shapes )
325356 arrays1 = xps .arrays (
326357 dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
0 commit comments