@@ -59,12 +59,14 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
5959
6060
6161def assert_shape (
62- func_name : str , out_shape : Shape , expected : Union [int , Shape ], / , ** kw
62+ func_name : str , out_shape : Union [ int , Shape ] , expected : Union [int , Shape ], / , ** kw
6363):
64- f_kw = ", " . join ( f" { k } = { v } " for k , v in kw . items ())
65- msg = f"out.shape= { out_shape } , but should be { expected } [ { func_name } ( { f_kw } )]"
64+ if isinstance ( out_shape , int ):
65+ out_shape = ( out_shape ,)
6666 if isinstance (expected , int ):
6767 expected = (expected ,)
68+ f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
69+ msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
6870 assert out_shape == expected , msg
6971
7072
@@ -183,7 +185,7 @@ def test_arange(dtype, data):
183185 else :
184186 _dtype = dtype
185187
186- # sanity check
188+ # sanity checks
187189 if dh .is_int_dtype (_dtype ):
188190 m , M = dh .dtype_ranges [_dtype ]
189191 assert m <= _start <= M
@@ -213,9 +215,10 @@ def test_arange(dtype, data):
213215 assert_default_float ("arange" , out .dtype )
214216 else :
215217 assert out .dtype == dtype
216- assert out .ndim == 1 , f"{ out .ndim = } , should be 1 [linspace()]"
218+ assert out .ndim == 1 , f"{ out .ndim = } , but should be 1 [linspace()]"
219+ f_func = f"[linspace({ start = } , { stop = } , { step = } )]"
217220 if dh .is_int_dtype (_dtype ):
218- assert out .size == size
221+ assert out .size == size , f" { out . size = } , but should be { size } { f_func } "
219222 else :
220223 # We check size is roughly as expected to avoid edge cases e.g.
221224 #
@@ -224,7 +227,11 @@ def test_arange(dtype, data):
224227 # >>> xp.arange(2, step=0.3333333333333333)
225228 # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226229 #
227- assert math .floor (math .sqrt (size )) <= out .size <= math .ceil (size ** 2 )
230+ min_size = math .floor (size * 0.9 )
231+ max_size = math .ceil (size * 1.1 )
232+ assert (
233+ min_size <= out .size <= max_size
234+ ), f"{ out .size = } , but should be roughly { size } { f_func } "
228235 assume (out .size == size )
229236 if dh .is_int_dtype (_dtype ):
230237 ah .assert_exactly_equal (out , ah .asarray (list (r ), dtype = _dtype ))
@@ -407,24 +414,22 @@ def test_linspace(num, dtype, endpoint, data):
407414 out = xp .linspace (start , stop , num , ** kw )
408415
409416 assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
417+ f_func = f"[linspace({ start = } , { stop = } , { num = } )]"
410418 if num > 0 :
411419 assert ah .equal (
412420 out [0 ], ah .asarray (start , dtype = out .dtype )
413- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace( { stop = } , { num = } )] "
421+ ), f"out[0]={ out [0 ]} , but should be { start } { f_func } "
414422 if endpoint :
415423 if num > 1 :
416424 assert ah .equal (
417425 out [- 1 ], ah .asarray (stop , dtype = out .dtype )
418- ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace( { start = } , { num = } )] "
426+ ), f"out[-1]={ out [- 1 ]} , but should be { stop } { f_func } "
419427 else :
420428 # linspace(..., num, endpoint=True) should return an array equivalent to
421429 # the first num elements when endpoint=False
422430 expected = xp .linspace (start , stop , num + 1 , dtype = dtype , endpoint = True )
423431 expected = expected [:- 1 ]
424432 ah .assert_exactly_equal (out , expected )
425- assert (
426- out .size == num
427- ), f"{ out .size = } , but should be { num = } [linspace({ start = } , { stop = } )]"
428433
429434
430435def make_one (dtype : DataType ) -> Scalar :
0 commit comments