@@ -111,7 +111,10 @@ def test_arange(dtype, data):
111111 _start = start
112112 _stop = stop
113113
114- tol = max (abs (_stop - _start ) / (hh .MAX_ARRAY_SIZE - 1 ), 0.01 )
114+ # tol is the minimum tolerance for step values, used to avoid scenarios
115+ # where xp.arange() produces arrays that would be over MAX_ARRAY_SIZE.
116+ tol = max (abs (_stop - _start ) / (math .sqrt (hh .MAX_ARRAY_SIZE )), 0.01 )
117+ assert tol != 0 , "tol must not equal 0" # sanity check
115118 assume (- tol > int_min )
116119 assume (tol < int_max )
117120 if dtype is None or dh .is_float_dtype (dtype ):
@@ -124,7 +127,7 @@ def test_arange(dtype, data):
124127 step_max = max (math .ceil (tol ), 1 )
125128 step_strats .append (xps .from_dtype (dtype , min_value = step_max ))
126129 step = data .draw (st .one_of (step_strats ), label = "step" )
127- assert step != 0 , f" { step = } must not equal 0" # sanity check
130+ assert step != 0 , " step must not equal 0" # sanity check
128131
129132 all_int = all (arg is None or isinstance (arg , int ) for arg in [start , stop , step ])
130133
@@ -147,9 +150,9 @@ def test_arange(dtype, data):
147150 pos_step = step > 0
148151 if _start != _stop and pos_range == pos_step :
149152 if pos_step :
150- condition = lambda x : x <= _stop
153+ condition = lambda x : x < _stop
151154 else :
152- condition = lambda x : x >= _stop
155+ condition = lambda x : x > _stop
153156 scalar_type = int if dh .is_int_dtype (_dtype ) else float
154157 elements = list (
155158 scalar_type (n ) for n in takewhile (condition , count (_start , step ))
@@ -159,7 +162,7 @@ def test_arange(dtype, data):
159162 size = len (elements )
160163 assert (
161164 size <= hh .MAX_ARRAY_SIZE
162- ), f"{ size = } should be no more than { hh .MAX_ARRAY_SIZE = } " # sanity check
165+ ), f"{ size = } should be no more than { hh .MAX_ARRAY_SIZE } " # sanity check
163166
164167 out = xp .arange (start , stop = stop , step = step , dtype = dtype )
165168
@@ -181,7 +184,7 @@ def test_arange(dtype, data):
181184 # >>> xp.arange(2, step=0.3333333333333333)
182185 # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
183186 #
184- assert out . size in ( size - 1 , size , size + 1 )
187+ assert math . floor ( math . sqrt ( size )) <= out . size <= math . ceil ( size ** 2 )
185188 assume (out .size == size )
186189 if dh .is_int_dtype (_dtype ):
187190 ah .assert_exactly_equal (out , ah .asarray (elements , dtype = _dtype ))
0 commit comments