@@ -225,12 +225,14 @@ def test_arange(dtype, data):
225225 # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226226 #
227227 assert math .floor (math .sqrt (size )) <= out .size <= math .ceil (size ** 2 )
228-
229228 assume (out .size == size )
230229 if dh .is_int_dtype (_dtype ):
231230 ah .assert_exactly_equal (out , ah .asarray (list (r ), dtype = _dtype ))
232231 else :
233- pass # TODO: either emulate array module behaviour or assert a rough equals
232+ if out .size > 0 :
233+ assert ah .equal (
234+ out [0 ], ah .asarray (_start , dtype = out .dtype )
235+ ), f"out[0]={ out [0 ]} , but should be { _start } [linspace({ start = } , { stop = } )]"
234236
235237
236238@given (hh .shapes (), hh .kwargs (dtype = st .none () | hh .shared_dtypes ))
@@ -357,15 +359,21 @@ def test_full_like(x, fill_value, kw):
357359finite_kw = {"allow_nan" : False , "allow_infinity" : False }
358360
359361
360- @st .composite
361- def int_stops (draw , start : int , min_gap : int , m : int , M : int ):
362- sign = draw (st .booleans ().map (int ))
363- max_gap = abs (M - m )
364- max_int = math .floor (math .sqrt (max_gap ))
365- gap = draw (st .just (0 ) | st .integers (1 , max_int ).map (lambda n : min_gap ** n ))
366- stop = start + sign * gap
367- assume (m <= stop <= M )
368- return stop
362+ def int_stops (
363+ start : int , num , dtype : DataType , endpoint : bool
364+ ) -> st .SearchStrategy [int ]:
365+ min_gap = num
366+ if endpoint :
367+ min_gap += 1
368+ m , M = dh .dtype_ranges [dtype ]
369+ max_pos_gap = M - start
370+ max_neg_gap = start - m
371+ max_pos_mul = max_pos_gap // min_gap
372+ max_neg_mul = max_neg_gap // min_gap
373+ return st .one_of (
374+ st .integers (0 , max_pos_mul ).map (lambda n : start + min_gap * n ),
375+ st .integers (0 , max_neg_mul ).map (lambda n : start - min_gap * n ),
376+ )
369377
370378
371379@given (
@@ -381,17 +389,13 @@ def test_linspace(num, dtype, endpoint, data):
381389 if dh .is_float_dtype (_dtype ):
382390 stop = data .draw (xps .from_dtype (_dtype , ** finite_kw ), label = "stop" )
383391 # avoid overflow errors
384- delta = ah .asarray (stop - start , dtype = _dtype )
385- assume (not ah .isnan (delta ))
392+ assume ( not ah .isnan ( ah . asarray (stop - start , dtype = _dtype )) )
393+ assume (not ah .isnan (ah . asarray ( start - stop , dtype = _dtype ) ))
386394 else :
387395 if num == 0 :
388396 stop = start
389397 else :
390- min_gap = num
391- if endpoint :
392- min_gap += 1
393- m , M = dh .dtype_ranges [_dtype ]
394- stop = data .draw (int_stops (start , min_gap , m , M ), label = "stop" )
398+ stop = data .draw (int_stops (start , num , _dtype , endpoint ), label = "stop" )
395399
396400 kw = data .draw (
397401 specified_kwargs (
@@ -403,7 +407,10 @@ def test_linspace(num, dtype, endpoint, data):
403407 out = xp .linspace (start , stop , num , ** kw )
404408
405409 assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
406-
410+ if num > 0 :
411+ assert ah .equal (
412+ out [0 ], ah .asarray (start , dtype = out .dtype )
413+ ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
407414 if endpoint :
408415 if num > 1 :
409416 assert ah .equal (
@@ -415,13 +422,9 @@ def test_linspace(num, dtype, endpoint, data):
415422 expected = xp .linspace (start , stop , num + 1 , dtype = dtype , endpoint = True )
416423 expected = expected [:- 1 ]
417424 ah .assert_exactly_equal (out , expected )
418-
419- if num > 0 :
420- assert ah .equal (
421- out [0 ], ah .asarray (start , dtype = out .dtype )
422- ), f"out[0]={ out [0 ]} , but should be { start = } [linspace({ stop = } , { num = } )]"
423-
424- # TODO: array assertions ala test_arange
425+ assert (
426+ out .size == num
427+ ), f"{ out .size = } , but should be { num = } [linspace({ start = } , { stop = } )]"
425428
426429
427430def make_one (dtype : DataType ) -> Scalar :
0 commit comments