2424@given (hh .mutually_promotable_dtypes (None ))
2525def test_result_type (dtypes ):
2626 out = xp .result_type (* dtypes )
27- ph .assert_dtype (' result_type' , dtypes , out , out_name = ' out' )
27+ ph .assert_dtype (" result_type" , dtypes , out , out_name = " out" )
2828
2929
30+ # The number and size of generated arrays is arbitrarily limited to prevent
31+ # meshgrid() running out of memory.
3032@given (
31- dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
33+ dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
3234 data = st .data (),
3335)
3436def test_meshgrid (dtypes , data ):
3537 arrays = []
36- shapes = data .draw (hh .mutually_broadcastable_shapes (len (dtypes )), label = 'shapes' )
38+ shapes = data .draw (
39+ hh .mutually_broadcastable_shapes (
40+ len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
41+ ),
42+ label = "shapes" ,
43+ )
3744 for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
38- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
45+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
3946 arrays .append (x )
4047 out = xp .meshgrid (* arrays )
4148 for i , x in enumerate (out ):
42- ph .assert_dtype (' meshgrid' , dtypes , x .dtype , out_name = f' out[{ i } ].dtype' )
49+ ph .assert_dtype (" meshgrid" , dtypes , x .dtype , out_name = f" out[{ i } ].dtype" )
4350
4451
4552@given (
@@ -50,10 +57,10 @@ def test_meshgrid(dtypes, data):
5057def test_concat (shape , dtypes , data ):
5158 arrays = []
5259 for i , dtype in enumerate (dtypes , 1 ):
53- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
60+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
5461 arrays .append (x )
5562 out = xp .concat (arrays )
56- ph .assert_dtype (' concat' , dtypes , out .dtype )
63+ ph .assert_dtype (" concat" , dtypes , out .dtype )
5764
5865
5966@given (
@@ -64,26 +71,26 @@ def test_concat(shape, dtypes, data):
6471def test_stack (shape , dtypes , data ):
6572 arrays = []
6673 for i , dtype in enumerate (dtypes , 1 ):
67- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
74+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
6875 arrays .append (x )
6976 out = xp .stack (arrays )
70- ph .assert_dtype (' stack' , dtypes , out .dtype )
77+ ph .assert_dtype (" stack" , dtypes , out .dtype )
7178
7279
7380bitwise_shift_funcs = [
74- ' bitwise_left_shift' ,
75- ' bitwise_right_shift' ,
76- ' __lshift__' ,
77- ' __rshift__' ,
78- ' __ilshift__' ,
79- ' __irshift__' ,
81+ " bitwise_left_shift" ,
82+ " bitwise_right_shift" ,
83+ " __lshift__" ,
84+ " __rshift__" ,
85+ " __ilshift__" ,
86+ " __irshift__" ,
8087]
8188
8289
8390# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
8491# generate array elements that are erroneous or undefined for a function.
8592func_elements = defaultdict (
86- lambda : None , {func : {' min_value' : 1 } for func in bitwise_shift_funcs }
93+ lambda : None , {func : {" min_value" : 1 } for func in bitwise_shift_funcs }
8794)
8895
8996
@@ -94,7 +101,7 @@ def make_id(
94101) -> str :
95102 f_args = dh .fmt_types (in_dtypes )
96103 f_out_dtype = dh .dtype_to_name [out_dtype ]
97- return f' { func_name } ({ f_args } ) -> { f_out_dtype } '
104+ return f" { func_name } ({ f_args } ) -> { f_out_dtype } "
98105
99106
100107func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
@@ -128,25 +135,25 @@ def make_id(
128135 raise NotImplementedError ()
129136
130137
131- @pytest .mark .parametrize (' func_name, in_dtypes, out_dtype' , func_params )
138+ @pytest .mark .parametrize (" func_name, in_dtypes, out_dtype" , func_params )
132139@given (data = st .data ())
133140def test_func_promotion (func_name , in_dtypes , out_dtype , data ):
134141 func = getattr (xp , func_name )
135142 elements = func_elements [func_name ]
136143 if len (in_dtypes ) == 1 :
137144 x = data .draw (
138145 xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
139- label = 'x' ,
146+ label = "x" ,
140147 )
141148 out = func (x )
142149 else :
143150 arrays = []
144151 shapes = data .draw (
145- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
152+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
146153 )
147154 for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
148155 x = data .draw (
149- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
156+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
150157 )
151158 arrays .append (x )
152159 try :
@@ -161,46 +168,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161168 p = pytest .param (
162169 (dtype1 , dtype2 ),
163170 promoted_dtype ,
164- id = make_id ('' , (dtype1 , dtype2 ), promoted_dtype ),
171+ id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
165172 )
166173 promotion_params .append (p )
167174
168175
169- @pytest .mark .parametrize (' in_dtypes, out_dtype' , promotion_params )
176+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , promotion_params )
170177@given (shapes = hh .mutually_broadcastable_shapes (3 ), data = st .data ())
171178def test_where (in_dtypes , out_dtype , shapes , data ):
172- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
173- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
174- cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = ' condition' )
179+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
180+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
181+ cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = " condition" )
175182 out = xp .where (cond , x1 , x2 )
176- ph .assert_dtype (' where' , in_dtypes , out .dtype , out_dtype )
183+ ph .assert_dtype (" where" , in_dtypes , out .dtype , out_dtype )
177184
178185
179186numeric_promotion_params = promotion_params [1 :]
180187
181188
182- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
189+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
183190@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 2 ), data = st .data ())
184191def test_tensordot (in_dtypes , out_dtype , shapes , data ):
185- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
186- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
192+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
193+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
187194 out = xp .tensordot (x1 , x2 )
188- ph .assert_dtype (' tensordot' , in_dtypes , out .dtype , out_dtype )
195+ ph .assert_dtype (" tensordot" , in_dtypes , out .dtype , out_dtype )
189196
190197
191- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
198+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
192199@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 1 ), data = st .data ())
193200def test_vecdot (in_dtypes , out_dtype , shapes , data ):
194- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
195- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
201+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
202+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
196203 out = xp .vecdot (x1 , x2 )
197- ph .assert_dtype (' vecdot' , in_dtypes , out .dtype , out_dtype )
204+ ph .assert_dtype (" vecdot" , in_dtypes , out .dtype , out_dtype )
198205
199206
200207op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
201208op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
202209for op , symbol in op_to_symbol .items ():
203- if op == ' __matmul__' :
210+ if op == " __matmul__" :
204211 continue
205212 valid_in_dtypes = dh .func_in_dtypes [op ]
206213 ndtypes = ph .nargs (op )
@@ -209,7 +216,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209216 out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
210217 p = pytest .param (
211218 op ,
212- f' { symbol } x' ,
219+ f" { symbol } x" ,
213220 (in_dtype ,),
214221 out_dtype ,
215222 id = make_id (op , (in_dtype ,), out_dtype ),
@@ -221,42 +228,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221228 out_dtype = xp .bool if dh .func_returns_bool [op ] else promoted_dtype
222229 p = pytest .param (
223230 op ,
224- f' x1 { symbol } x2' ,
231+ f" x1 { symbol } x2" ,
225232 (in_dtype1 , in_dtype2 ),
226233 out_dtype ,
227234 id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
228235 )
229236 op_params .append (p )
230237# We generate params for abs seperately as it does not have an associated symbol
231- for in_dtype in dh .func_in_dtypes [' __abs__' ]:
238+ for in_dtype in dh .func_in_dtypes [" __abs__" ]:
232239 p = pytest .param (
233- ' __abs__' ,
234- ' abs(x)' ,
240+ " __abs__" ,
241+ " abs(x)" ,
235242 (in_dtype ,),
236243 in_dtype ,
237- id = make_id (' __abs__' , (in_dtype ,), in_dtype ),
244+ id = make_id (" __abs__" , (in_dtype ,), in_dtype ),
238245 )
239246 op_params .append (p )
240247
241248
242- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , op_params )
249+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , op_params )
243250@given (data = st .data ())
244251def test_op_promotion (op , expr , in_dtypes , out_dtype , data ):
245252 elements = func_elements [func_name ]
246253 if len (in_dtypes ) == 1 :
247254 x = data .draw (
248255 xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
249- label = 'x' ,
256+ label = "x" ,
250257 )
251- out = eval (expr , {'x' : x })
258+ out = eval (expr , {"x" : x })
252259 else :
253260 locals_ = {}
254261 shapes = data .draw (
255- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
262+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
256263 )
257264 for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
258- locals_ [f' x{ i } ' ] = data .draw (
259- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
265+ locals_ [f" x{ i } " ] = data .draw (
266+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
260267 )
261268 try :
262269 out = eval (expr , locals_ )
@@ -267,7 +274,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267274
268275inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
269276for op , symbol in dh .inplace_op_to_symbol .items ():
270- if op == ' __imatmul__' :
277+ if op == " __imatmul__" :
271278 continue
272279 valid_in_dtypes = dh .func_in_dtypes [op ]
273280 for (in_dtype1 , in_dtype2 ), promoted_dtype in dh .promotion_table .items ():
@@ -278,44 +285,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278285 ):
279286 p = pytest .param (
280287 op ,
281- f' x1 { symbol } x2' ,
288+ f" x1 { symbol } x2" ,
282289 (in_dtype1 , in_dtype2 ),
283290 promoted_dtype ,
284291 id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
285292 )
286293 inplace_params .append (p )
287294
288295
289- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , inplace_params )
296+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , inplace_params )
290297@given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
291298def test_inplace_op_promotion (op , expr , in_dtypes , out_dtype , shapes , data ):
292299 assume (len (shapes [0 ]) >= len (shapes [1 ]))
293300 elements = func_elements [func_name ]
294301 x1 = data .draw (
295- xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = 'x1'
302+ xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = "x1"
296303 )
297304 x2 = data .draw (
298- xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = 'x2'
305+ xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = "x2"
299306 )
300- locals_ = {'x1' : x1 , 'x2' : x2 }
307+ locals_ = {"x1" : x1 , "x2" : x2 }
301308 try :
302309 exec (expr , locals_ )
303310 except OverflowError :
304311 reject ()
305- x1 = locals_ ['x1' ]
306- ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = ' x1.dtype' )
312+ x1 = locals_ ["x1" ]
313+ ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = " x1.dtype" )
307314
308315
309316op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
310317for op , symbol in dh .binary_op_to_symbol .items ():
311- if op == ' __matmul__' :
318+ if op == " __matmul__" :
312319 continue
313320 for in_dtype in dh .func_in_dtypes [op ]:
314321 out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
315322 for in_stype in dh .dtype_to_scalars [in_dtype ]:
316323 p = pytest .param (
317324 op ,
318- f' x { symbol } s' ,
325+ f" x { symbol } s" ,
319326 in_dtype ,
320327 in_stype ,
321328 out_dtype ,
@@ -324,57 +331,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324331 op_scalar_params .append (p )
325332
326333
327- @pytest .mark .parametrize (' op, expr, in_dtype, in_stype, out_dtype' , op_scalar_params )
334+ @pytest .mark .parametrize (" op, expr, in_dtype, in_stype, out_dtype" , op_scalar_params )
328335@given (data = st .data ())
329336def test_op_scalar_promotion (op , expr , in_dtype , in_stype , out_dtype , data ):
330337 elements = func_elements [func_name ]
331- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
332- s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = ' scalar' )
338+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
339+ s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = " scalar" )
333340 x = data .draw (
334- xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = 'x'
341+ xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = "x"
335342 )
336343 try :
337- out = eval (expr , {'x' : x , 's' : s })
344+ out = eval (expr , {"x" : x , "s" : s })
338345 except OverflowError :
339346 reject ()
340347 ph .assert_dtype (op , (in_dtype , in_stype ), out .dtype , out_dtype )
341348
342349
343350inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
344351for op , symbol in dh .inplace_op_to_symbol .items ():
345- if op == ' __imatmul__' :
352+ if op == " __imatmul__" :
346353 continue
347354 for dtype in dh .func_in_dtypes [op ]:
348355 for in_stype in dh .dtype_to_scalars [dtype ]:
349356 p = pytest .param (
350357 op ,
351- f' x { symbol } s' ,
358+ f" x { symbol } s" ,
352359 dtype ,
353360 in_stype ,
354361 id = make_id (op , (dtype , in_stype ), dtype ),
355362 )
356363 inplace_scalar_params .append (p )
357364
358365
359- @pytest .mark .parametrize (' op, expr, dtype, in_stype' , inplace_scalar_params )
366+ @pytest .mark .parametrize (" op, expr, dtype, in_stype" , inplace_scalar_params )
360367@given (data = st .data ())
361368def test_inplace_op_scalar_promotion (op , expr , dtype , in_stype , data ):
362369 elements = func_elements [func_name ]
363- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
364- s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = ' scalar' )
370+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
371+ s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = " scalar" )
365372 x = data .draw (
366- xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = 'x'
373+ xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = "x"
367374 )
368- locals_ = {'x' : x , 's' : s }
375+ locals_ = {"x" : x , "s" : s }
369376 try :
370377 exec (expr , locals_ )
371378 except OverflowError :
372379 reject ()
373- x = locals_ ['x' ]
374- assert x .dtype == dtype , f' { x .dtype = !s} , but should be { dtype } '
375- ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = ' x.dtype' )
380+ x = locals_ ["x" ]
381+ assert x .dtype == dtype , f" { x .dtype = !s} , but should be { dtype } "
382+ ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = " x.dtype" )
376383
377384
378- if __name__ == ' __main__' :
385+ if __name__ == " __main__" :
379386 for (i , j ), p in dh .promotion_table .items ():
380- print (f' ({ i } , { j } ) -> { p } ' )
387+ print (f" ({ i } , { j } ) -> { p } " )
0 commit comments