From 9c0a54abe9e45dc1c6fdf01be55713a8efc3903b Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 20 Oct 2021 11:34:06 +0100 Subject: [PATCH] Raise `InvalidArgument` with bad shape in `xps.arrays()` --- hypothesis-python/RELEASE.rst | 4 ++++ hypothesis-python/src/hypothesis/extra/array_api.py | 11 ++++++----- .../tests/array_api/test_argument_validation.py | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 hypothesis-python/RELEASE.rst diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..567d891614 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,4 @@ +RELEASE_TYPE: patch + +This patch adds an error for when ``shapes`` in :func:`xps.arrays()` is not +passed as either a valid shape or strategy. diff --git a/hypothesis-python/src/hypothesis/extra/array_api.py b/hypothesis-python/src/hypothesis/extra/array_api.py index 1d281dccc1..b29703d239 100644 --- a/hypothesis-python/src/hypothesis/extra/array_api.py +++ b/hypothesis-python/src/hypothesis/extra/array_api.py @@ -468,16 +468,17 @@ def _arrays( return dtype.flatmap( lambda d: _arrays(xp, d, shape, elements=elements, fill=fill, unique=unique) ) + elif isinstance(dtype, str): + dtype = dtype_from_name(xp, dtype) + if isinstance(shape, st.SearchStrategy): return shape.flatmap( lambda s: _arrays(xp, dtype, s, elements=elements, fill=fill, unique=unique) ) - - if isinstance(dtype, str): - dtype = dtype_from_name(xp, dtype) - - if isinstance(shape, int): + elif isinstance(shape, int): shape = (shape,) + elif not isinstance(shape, tuple): + raise InvalidArgument(f"shape={shape} is not a valid shape or strategy") check_argument( all(isinstance(x, int) and x >= 0 for x in shape), f"shape={shape!r}, but all dimensions must be non-negative integers.", diff --git a/hypothesis-python/tests/array_api/test_argument_validation.py b/hypothesis-python/tests/array_api/test_argument_validation.py index 7afbd961ec..7717abdd1c 100644 --- a/hypothesis-python/tests/array_api/test_argument_validation.py +++ b/hypothesis-python/tests/array_api/test_argument_validation.py @@ -33,6 +33,7 @@ def e(a, **kwargs): e(xps.arrays, dtype=xp.int8, shape=(0.5,)), e(xps.arrays, dtype=xp.int8, shape=1, fill=3), e(xps.arrays, dtype=xp.int8, shape=1, elements="not a strategy"), + e(xps.arrays, dtype=xp.int8, shape=lambda: "not a strategy"), e(xps.array_shapes, min_side=2, max_side=1), e(xps.array_shapes, min_dims=3, max_dims=2), e(xps.array_shapes, min_dims=-1),