Skip to content

Commit

Permalink
Merge pull request #319 from asmeurer/reshape-fix
Browse files Browse the repository at this point in the history
Fix test_reshape
  • Loading branch information
ev-br authored Nov 23, 2024
2 parents ad81cf6 + bcfcdba commit c2e010e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
62 changes: 62 additions & 0 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,68 @@ def shapes(**kw):
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
)

def _factorize(n: int) -> List[int]:
# Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
factors = []
while n % 2 == 0:
factors.append(2)
n //= 2

for i in range(3, int(math.sqrt(n)) + 1, 2):
while n % i == 0:
factors.append(i)
n //= i

if n > 1: # n is a prime number greater than 2
factors.append(n)

return factors

MAX_SIDE = MAX_ARRAY_SIZE // 64
# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)


@composite
def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
"""
Generate shape tuples whose product equals the product of array_shape.
"""
shape = draw(arr_shape)

array_size = math.prod(shape)

n_dims = draw(ndims)

# Handle special cases
if array_size == 0:
# Generate a random tuple, and ensure at least one of the entries is 0
result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
pos = draw(integers(0, n_dims - 1))
result[pos] = 0
return tuple(result)

if array_size == 1:
return tuple(1 for _ in range(n_dims))

# Get prime factorization
factors = _factorize(array_size)

# Distribute prime factors randomly
result = [1] * n_dims
for factor in factors:
pos = draw(integers(0, n_dims - 1))
result[pos] *= factor

assert math.prod(result) == array_size

# An element of the reshape tuple can be -1, which means it is a stand-in
# for the remaining factors.
if draw(booleans()):
pos = draw(integers(0, n_dims - 1))
result[pos] = -1

return tuple(result)

one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)

Expand Down
25 changes: 5 additions & 20 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from . import xps
from .typing import Array, Shape

MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims


def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
key = "shape"
Expand Down Expand Up @@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data):
shape_strat = hh.shapes()
else:
_axis = axis if axis >= 0 else len(base_shape) + axis
shape_strat = st.integers(0, MAX_SIDE).map(
shape_strat = st.integers(0, hh.MAX_SIDE).map(
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
)
arrays = []
Expand Down Expand Up @@ -348,26 +345,14 @@ def test_repeat(x, kw, data):
kw=kw)
start = end

@st.composite
def reshape_shapes(draw, shape):
size = 1 if len(shape) == 0 else math.prod(shape)
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
assume(all(side <= MAX_SIDE for side in rshape))
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
index = draw(st.integers(0, len(rshape) - 1))
rshape[index] = -1
return tuple(rshape)

reshape_shape = st.shared(hh.shapes(), key="reshape_shape")

@pytest.mark.unvectorized
@pytest.mark.skip("flaky") # TODO: fix!
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)),
data=st.data(),
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
shape=hh.reshape_shapes(reshape_shape),
)
def test_reshape(x, data):
shape = data.draw(reshape_shapes(x.shape))

def test_reshape(x, shape):
out = xp.reshape(x, shape)

ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)
Expand Down

0 comments on commit c2e010e

Please sign in to comment.