1010from . import dtype_helpers as dh
1111from . import pytest_helpers as ph
1212from . import xps
13- from .typing import Shape , DataType
13+ from .typing import Shape , DataType , Array
1414
1515
1616def assert_default_float (func_name : str , dtype : DataType ):
@@ -33,11 +33,7 @@ def assert_default_int(func_name: str, dtype: DataType):
3333 assert dtype == dh .default_int , msg
3434
3535
36- def assert_kw_dtype (
37- func_name : str ,
38- kw_dtype : DataType ,
39- out_dtype : DataType ,
40- ):
36+ def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
4137 f_kw_dtype = dh .dtype_to_name [kw_dtype ]
4238 f_out_dtype = dh .dtype_to_name [out_dtype ]
4339 msg = (
@@ -47,12 +43,7 @@ def assert_kw_dtype(
4743 assert out_dtype == kw_dtype , msg
4844
4945
50- def assert_shape (
51- func_name : str ,
52- out_shape : Shape ,
53- expected : Union [int , Shape ],
54- ** kw ,
55- ):
46+ def assert_shape (func_name : str , out_shape : Shape , expected : Union [int , Shape ], ** kw ):
5647 f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
5748 msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
5849 if isinstance (expected , int ):
@@ -61,6 +52,18 @@ def assert_shape(
6152
6253
6354
55+ def assert_fill (func_name : str , fill : float , dtype : DataType , out : Array , ** kw ):
56+ f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
57+ msg = (
58+ f"out not filled with { fill } [{ func_name } ({ f_kw } )]\n "
59+ f"{ out = } "
60+ )
61+ if math .isnan (fill ):
62+ assert ah .all (ah .isnan (out )), msg
63+ else :
64+ assert ah .all (ah .equal (out , ah .asarray (fill , dtype = dtype ))), msg
65+
66+
6467# Testing xp.arange() requires bounding the start/stop/step arguments to only
6568# test argument combinations compliant with the Array API, as well as to not
6669# produce arrays with sizes not supproted by an array module.
@@ -234,8 +237,9 @@ def test_eye(n_rows, n_cols, kw):
234237)
235238
236239
240+
237241@st .composite
238- def full_fill_values (draw ):
242+ def full_fill_values (draw ) -> st . SearchStrategy [ float ] :
239243 kw = draw (st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_kw" ))
240244 dtype = kw .get ("dtype" , None ) or draw (default_safe_dtypes )
241245 return draw (xps .from_dtype (dtype ))
@@ -266,10 +270,7 @@ def test_full(shape, fill_value, kw):
266270 else :
267271 assert_kw_dtype ("full" , kw ["dtype" ], out .dtype )
268272 assert_shape ("full" , out .shape , shape , shape = shape )
269- if dh .is_float_dtype (out .dtype ) and math .isnan (fill_value ):
270- assert ah .all (ah .isnan (out )), "full() array did not equal the fill value"
271- else :
272- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), "full() array did not equal the fill value"
273+ assert_fill ("full" , fill_value , dtype , out , fill_value = fill_value )
273274
274275
275276@st .composite
@@ -291,13 +292,8 @@ def test_full_like(x, fill_value, kw):
291292 ph .assert_dtype ("full_like" , (x .dtype ,), out .dtype )
292293 else :
293294 assert_kw_dtype ("full_like" , kw ["dtype" ], out .dtype )
294-
295295 assert_shape ("full_like" , out .shape , x .shape )
296- if dh .is_float_dtype (dtype ) and math .isnan (fill_value ):
297- assert ah .all (ah .isnan (out )), "full_like() array did not equal the fill value"
298- else :
299- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), "full_like() array did not equal the fill value"
300-
296+ assert_fill ("full_like" , fill_value , dtype , out , fill_value = fill_value )
301297
302298finite_kw = {"allow_nan" : False , "allow_infinity" : False }
303299
@@ -364,7 +360,7 @@ def test_linspace(num, dtype, endpoint, data):
364360 # TODO: array assertions ala test_arange
365361
366362
367- def make_one (dtype ) :
363+ def make_one (dtype : DataType ) -> Union [ bool , float ] :
368364 if dtype is None or dh .is_float_dtype (dtype ):
369365 return 1.0
370366 elif dh .is_int_dtype (dtype ):
@@ -382,7 +378,7 @@ def test_ones(shape, kw):
382378 assert_kw_dtype ("ones" , kw ["dtype" ], out .dtype )
383379 assert_shape ("ones" , out .shape , shape , shape = shape )
384380 dtype = kw .get ("dtype" , None ) or dh .default_float
385- assert ah . all ( ah . equal ( out , ah . asarray ( make_one (dtype ), dtype = dtype ))), "ones() array did not equal 1"
381+ assert_fill ( "ones" , make_one (dtype ), dtype , out )
386382
387383
388384@given (
@@ -397,10 +393,10 @@ def test_ones_like(x, kw):
397393 assert_kw_dtype ("ones_like" , kw ["dtype" ], out .dtype )
398394 assert_shape ("ones_like" , out .shape , x .shape )
399395 dtype = kw .get ("dtype" , None ) or x .dtype
400- assert ah . all ( ah . equal ( out , ah . asarray ( make_one (dtype ), dtype = dtype ))), "ones_like() array elements did not equal 1"
396+ assert_fill ( "ones_like" , make_one (dtype ), dtype , out )
401397
402398
403- def make_zero (dtype ) :
399+ def make_zero (dtype : DataType ) -> Union [ bool , float ] :
404400 if dtype is None or dh .is_float_dtype (dtype ):
405401 return 0.0
406402 elif dh .is_int_dtype (dtype ):
@@ -418,7 +414,7 @@ def test_zeros(shape, kw):
418414 assert_kw_dtype ("zeros" , kw ["dtype" ], out .dtype )
419415 assert_shape ("zeros" , out .shape , shape , shape = shape )
420416 dtype = kw .get ("dtype" , None ) or dh .default_float
421- assert ah . all ( ah . equal ( out , ah . asarray ( make_zero (dtype ), dtype = dtype ))), "zeros() array did not equal 0"
417+ assert_fill ( "zeros" , make_zero (dtype ), dtype , out )
422418
423419
424420@given (
@@ -433,4 +429,4 @@ def test_zeros_like(x, kw):
433429 assert_kw_dtype ("zeros_like" , kw ["dtype" ], out .dtype )
434430 assert_shape ("zeros_like" , out .shape , x .shape )
435431 dtype = kw .get ("dtype" , None ) or x .dtype
436- assert ah . all ( ah . equal ( out , ah . asarray ( make_zero (dtype ), dtype = out . dtype ))), "xp.zeros_like() array elements did not ah.all xp.equal 0"
432+ assert_fill ( "zeros_like" , make_zero (dtype ), dtype , out )
0 commit comments