11"""
22https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33"""
4+ import math
45from collections import defaultdict
56from typing import Tuple , Union , List
67
2425@given (hh .mutually_promotable_dtypes (None ))
2526def test_result_type (dtypes ):
2627 out = xp .result_type (* dtypes )
27- ph .assert_dtype (' result_type' , dtypes , out , out_name = ' out' )
28+ ph .assert_dtype (" result_type" , dtypes , out , out_name = " out" )
2829
2930
31+ # The number and size of generated arrays is arbitrarily limited to prevent
32+ # meshgrid() running out of memory.
3033@given (
31- dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
34+ dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
3235 data = st .data (),
3336)
3437def test_meshgrid (dtypes , data ):
3538 arrays = []
36- shapes = data .draw (hh .mutually_broadcastable_shapes (len (dtypes )), label = 'shapes' )
39+ shapes = data .draw (
40+ hh .mutually_broadcastable_shapes (
41+ len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
42+ ),
43+ label = "shapes" ,
44+ )
3745 for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
38- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
46+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
3947 arrays .append (x )
48+ assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
4049 out = xp .meshgrid (* arrays )
4150 for i , x in enumerate (out ):
42- ph .assert_dtype (' meshgrid' , dtypes , x .dtype , out_name = f' out[{ i } ].dtype' )
51+ ph .assert_dtype (" meshgrid" , dtypes , x .dtype , out_name = f" out[{ i } ].dtype" )
4352
4453
4554@given (
@@ -50,10 +59,10 @@ def test_meshgrid(dtypes, data):
5059def test_concat (shape , dtypes , data ):
5160 arrays = []
5261 for i , dtype in enumerate (dtypes , 1 ):
53- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
62+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
5463 arrays .append (x )
5564 out = xp .concat (arrays )
56- ph .assert_dtype (' concat' , dtypes , out .dtype )
65+ ph .assert_dtype (" concat" , dtypes , out .dtype )
5766
5867
5968@given (
@@ -64,26 +73,26 @@ def test_concat(shape, dtypes, data):
6473def test_stack (shape , dtypes , data ):
6574 arrays = []
6675 for i , dtype in enumerate (dtypes , 1 ):
67- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
76+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
6877 arrays .append (x )
6978 out = xp .stack (arrays )
70- ph .assert_dtype (' stack' , dtypes , out .dtype )
79+ ph .assert_dtype (" stack" , dtypes , out .dtype )
7180
7281
7382bitwise_shift_funcs = [
74- ' bitwise_left_shift' ,
75- ' bitwise_right_shift' ,
76- ' __lshift__' ,
77- ' __rshift__' ,
78- ' __ilshift__' ,
79- ' __irshift__' ,
83+ " bitwise_left_shift" ,
84+ " bitwise_right_shift" ,
85+ " __lshift__" ,
86+ " __rshift__" ,
87+ " __ilshift__" ,
88+ " __irshift__" ,
8089]
8190
8291
8392# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
8493# generate array elements that are erroneous or undefined for a function.
8594func_elements = defaultdict (
86- lambda : None , {func : {' min_value' : 1 } for func in bitwise_shift_funcs }
95+ lambda : None , {func : {" min_value" : 1 } for func in bitwise_shift_funcs }
8796)
8897
8998
@@ -94,7 +103,7 @@ def make_id(
94103) -> str :
95104 f_args = dh .fmt_types (in_dtypes )
96105 f_out_dtype = dh .dtype_to_name [out_dtype ]
97- return f' { func_name } ({ f_args } ) -> { f_out_dtype } '
106+ return f" { func_name } ({ f_args } ) -> { f_out_dtype } "
98107
99108
100109func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
@@ -128,25 +137,25 @@ def make_id(
128137 raise NotImplementedError ()
129138
130139
131- @pytest .mark .parametrize (' func_name, in_dtypes, out_dtype' , func_params )
140+ @pytest .mark .parametrize (" func_name, in_dtypes, out_dtype" , func_params )
132141@given (data = st .data ())
133142def test_func_promotion (func_name , in_dtypes , out_dtype , data ):
134143 func = getattr (xp , func_name )
135144 elements = func_elements [func_name ]
136145 if len (in_dtypes ) == 1 :
137146 x = data .draw (
138147 xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
139- label = 'x' ,
148+ label = "x" ,
140149 )
141150 out = func (x )
142151 else :
143152 arrays = []
144153 shapes = data .draw (
145- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
154+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
146155 )
147156 for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
148157 x = data .draw (
149- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
158+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
150159 )
151160 arrays .append (x )
152161 try :
@@ -161,46 +170,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161170 p = pytest .param (
162171 (dtype1 , dtype2 ),
163172 promoted_dtype ,
164- id = make_id ('' , (dtype1 , dtype2 ), promoted_dtype ),
173+ id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
165174 )
166175 promotion_params .append (p )
167176
168177
169- @pytest .mark .parametrize (' in_dtypes, out_dtype' , promotion_params )
178+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , promotion_params )
170179@given (shapes = hh .mutually_broadcastable_shapes (3 ), data = st .data ())
171180def test_where (in_dtypes , out_dtype , shapes , data ):
172- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
173- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
174- cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = ' condition' )
181+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
182+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
183+ cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = " condition" )
175184 out = xp .where (cond , x1 , x2 )
176- ph .assert_dtype (' where' , in_dtypes , out .dtype , out_dtype )
185+ ph .assert_dtype (" where" , in_dtypes , out .dtype , out_dtype )
177186
178187
179188numeric_promotion_params = promotion_params [1 :]
180189
181190
182- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
191+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
183192@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 2 ), data = st .data ())
184193def test_tensordot (in_dtypes , out_dtype , shapes , data ):
185- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
186- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
194+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
195+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
187196 out = xp .tensordot (x1 , x2 )
188- ph .assert_dtype (' tensordot' , in_dtypes , out .dtype , out_dtype )
197+ ph .assert_dtype (" tensordot" , in_dtypes , out .dtype , out_dtype )
189198
190199
191- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
200+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
192201@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 1 ), data = st .data ())
193202def test_vecdot (in_dtypes , out_dtype , shapes , data ):
194- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
195- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
203+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
204+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
196205 out = xp .vecdot (x1 , x2 )
197- ph .assert_dtype (' vecdot' , in_dtypes , out .dtype , out_dtype )
206+ ph .assert_dtype (" vecdot" , in_dtypes , out .dtype , out_dtype )
198207
199208
200209op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
201210op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
202211for op , symbol in op_to_symbol .items ():
203- if op == ' __matmul__' :
212+ if op == " __matmul__" :
204213 continue
205214 valid_in_dtypes = dh .func_in_dtypes [op ]
206215 ndtypes = ph .nargs (op )
@@ -209,7 +218,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209218 out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
210219 p = pytest .param (
211220 op ,
212- f' { symbol } x' ,
221+ f" { symbol } x" ,
213222 (in_dtype ,),
214223 out_dtype ,
215224 id = make_id (op , (in_dtype ,), out_dtype ),
@@ -221,42 +230,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221230 out_dtype = xp .bool if dh .func_returns_bool [op ] else promoted_dtype
222231 p = pytest .param (
223232 op ,
224- f' x1 { symbol } x2' ,
233+ f" x1 { symbol } x2" ,
225234 (in_dtype1 , in_dtype2 ),
226235 out_dtype ,
227236 id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
228237 )
229238 op_params .append (p )
230239# We generate params for abs seperately as it does not have an associated symbol
231- for in_dtype in dh .func_in_dtypes [' __abs__' ]:
240+ for in_dtype in dh .func_in_dtypes [" __abs__" ]:
232241 p = pytest .param (
233- ' __abs__' ,
234- ' abs(x)' ,
242+ " __abs__" ,
243+ " abs(x)" ,
235244 (in_dtype ,),
236245 in_dtype ,
237- id = make_id (' __abs__' , (in_dtype ,), in_dtype ),
246+ id = make_id (" __abs__" , (in_dtype ,), in_dtype ),
238247 )
239248 op_params .append (p )
240249
241250
242- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , op_params )
251+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , op_params )
243252@given (data = st .data ())
244253def test_op_promotion (op , expr , in_dtypes , out_dtype , data ):
245254 elements = func_elements [func_name ]
246255 if len (in_dtypes ) == 1 :
247256 x = data .draw (
248257 xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
249- label = 'x' ,
258+ label = "x" ,
250259 )
251- out = eval (expr , {'x' : x })
260+ out = eval (expr , {"x" : x })
252261 else :
253262 locals_ = {}
254263 shapes = data .draw (
255- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
264+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
256265 )
257266 for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
258- locals_ [f' x{ i } ' ] = data .draw (
259- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
267+ locals_ [f" x{ i } " ] = data .draw (
268+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
260269 )
261270 try :
262271 out = eval (expr , locals_ )
@@ -267,7 +276,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267276
268277inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
269278for op , symbol in dh .inplace_op_to_symbol .items ():
270- if op == ' __imatmul__' :
279+ if op == " __imatmul__" :
271280 continue
272281 valid_in_dtypes = dh .func_in_dtypes [op ]
273282 for (in_dtype1 , in_dtype2 ), promoted_dtype in dh .promotion_table .items ():
@@ -278,44 +287,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278287 ):
279288 p = pytest .param (
280289 op ,
281- f' x1 { symbol } x2' ,
290+ f" x1 { symbol } x2" ,
282291 (in_dtype1 , in_dtype2 ),
283292 promoted_dtype ,
284293 id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
285294 )
286295 inplace_params .append (p )
287296
288297
289- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , inplace_params )
298+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , inplace_params )
290299@given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
291300def test_inplace_op_promotion (op , expr , in_dtypes , out_dtype , shapes , data ):
292301 assume (len (shapes [0 ]) >= len (shapes [1 ]))
293302 elements = func_elements [func_name ]
294303 x1 = data .draw (
295- xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = 'x1'
304+ xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = "x1"
296305 )
297306 x2 = data .draw (
298- xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = 'x2'
307+ xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = "x2"
299308 )
300- locals_ = {'x1' : x1 , 'x2' : x2 }
309+ locals_ = {"x1" : x1 , "x2" : x2 }
301310 try :
302311 exec (expr , locals_ )
303312 except OverflowError :
304313 reject ()
305- x1 = locals_ ['x1' ]
306- ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = ' x1.dtype' )
314+ x1 = locals_ ["x1" ]
315+ ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = " x1.dtype" )
307316
308317
309318op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
310319for op , symbol in dh .binary_op_to_symbol .items ():
311- if op == ' __matmul__' :
320+ if op == " __matmul__" :
312321 continue
313322 for in_dtype in dh .func_in_dtypes [op ]:
314323 out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
315324 for in_stype in dh .dtype_to_scalars [in_dtype ]:
316325 p = pytest .param (
317326 op ,
318- f' x { symbol } s' ,
327+ f" x { symbol } s" ,
319328 in_dtype ,
320329 in_stype ,
321330 out_dtype ,
@@ -324,57 +333,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324333 op_scalar_params .append (p )
325334
326335
327- @pytest .mark .parametrize (' op, expr, in_dtype, in_stype, out_dtype' , op_scalar_params )
336+ @pytest .mark .parametrize (" op, expr, in_dtype, in_stype, out_dtype" , op_scalar_params )
328337@given (data = st .data ())
329338def test_op_scalar_promotion (op , expr , in_dtype , in_stype , out_dtype , data ):
330339 elements = func_elements [func_name ]
331- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
332- s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = ' scalar' )
340+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
341+ s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = " scalar" )
333342 x = data .draw (
334- xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = 'x'
343+ xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = "x"
335344 )
336345 try :
337- out = eval (expr , {'x' : x , 's' : s })
346+ out = eval (expr , {"x" : x , "s" : s })
338347 except OverflowError :
339348 reject ()
340349 ph .assert_dtype (op , (in_dtype , in_stype ), out .dtype , out_dtype )
341350
342351
343352inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
344353for op , symbol in dh .inplace_op_to_symbol .items ():
345- if op == ' __imatmul__' :
354+ if op == " __imatmul__" :
346355 continue
347356 for dtype in dh .func_in_dtypes [op ]:
348357 for in_stype in dh .dtype_to_scalars [dtype ]:
349358 p = pytest .param (
350359 op ,
351- f' x { symbol } s' ,
360+ f" x { symbol } s" ,
352361 dtype ,
353362 in_stype ,
354363 id = make_id (op , (dtype , in_stype ), dtype ),
355364 )
356365 inplace_scalar_params .append (p )
357366
358367
359- @pytest .mark .parametrize (' op, expr, dtype, in_stype' , inplace_scalar_params )
368+ @pytest .mark .parametrize (" op, expr, dtype, in_stype" , inplace_scalar_params )
360369@given (data = st .data ())
361370def test_inplace_op_scalar_promotion (op , expr , dtype , in_stype , data ):
362371 elements = func_elements [func_name ]
363- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
364- s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = ' scalar' )
372+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
373+ s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = " scalar" )
365374 x = data .draw (
366- xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = 'x'
375+ xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = "x"
367376 )
368- locals_ = {'x' : x , 's' : s }
377+ locals_ = {"x" : x , "s" : s }
369378 try :
370379 exec (expr , locals_ )
371380 except OverflowError :
372381 reject ()
373- x = locals_ ['x' ]
374- assert x .dtype == dtype , f' { x .dtype = !s} , but should be { dtype } '
375- ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = ' x.dtype' )
382+ x = locals_ ["x" ]
383+ assert x .dtype == dtype , f" { x .dtype = !s} , but should be { dtype } "
384+ ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = " x.dtype" )
376385
377386
378- if __name__ == ' __main__' :
387+ if __name__ == " __main__" :
379388 for (i , j ), p in dh .promotion_table .items ():
380- print (f' ({ i } , { j } ) -> { p } ' )
389+ print (f" ({ i } , { j } ) -> { p } " )
0 commit comments