Skip to content

Commit 001d86d

Browse files
committed
Fix tolerance issues with test_arange
1 parent d7c8844 commit 001d86d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ def test_arange(dtype, data):
111111
_start = start
112112
_stop = stop
113113

114-
tol = max(abs(_stop - _start) / (hh.MAX_ARRAY_SIZE - 1), 0.01)
114+
# tol is the minimum tolerance for step values, used to avoid scenarios
115+
# where xp.arange() produces arrays that would be over MAX_ARRAY_SIZE.
116+
tol = max(abs(_stop - _start) / (math.sqrt(hh.MAX_ARRAY_SIZE)), 0.01)
117+
assert tol != 0, "tol must not equal 0" # sanity check
115118
assume(-tol > int_min)
116119
assume(tol < int_max)
117120
if dtype is None or dh.is_float_dtype(dtype):
@@ -124,7 +127,7 @@ def test_arange(dtype, data):
124127
step_max = max(math.ceil(tol), 1)
125128
step_strats.append(xps.from_dtype(dtype, min_value=step_max))
126129
step = data.draw(st.one_of(step_strats), label="step")
127-
assert step != 0, f"{step=} must not equal 0" # sanity check
130+
assert step != 0, "step must not equal 0" # sanity check
128131

129132
all_int = all(arg is None or isinstance(arg, int) for arg in [start, stop, step])
130133

@@ -147,9 +150,9 @@ def test_arange(dtype, data):
147150
pos_step = step > 0
148151
if _start != _stop and pos_range == pos_step:
149152
if pos_step:
150-
condition = lambda x: x <= _stop
153+
condition = lambda x: x < _stop
151154
else:
152-
condition = lambda x: x >= _stop
155+
condition = lambda x: x > _stop
153156
scalar_type = int if dh.is_int_dtype(_dtype) else float
154157
elements = list(
155158
scalar_type(n) for n in takewhile(condition, count(_start, step))
@@ -159,7 +162,7 @@ def test_arange(dtype, data):
159162
size = len(elements)
160163
assert (
161164
size <= hh.MAX_ARRAY_SIZE
162-
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE=}" # sanity check
165+
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
163166

164167
out = xp.arange(start, stop=stop, step=step, dtype=dtype)
165168

@@ -181,7 +184,7 @@ def test_arange(dtype, data):
181184
# >>> xp.arange(2, step=0.3333333333333333)
182185
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
183186
#
184-
assert out.size in (size - 1, size, size + 1)
187+
assert math.floor(math.sqrt(size)) <= out.size <= math.ceil(size ** 2)
185188
assume(out.size == size)
186189
if dh.is_int_dtype(_dtype):
187190
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))

0 commit comments

Comments
 (0)