|
3 | 3 | from itertools import product |
4 | 4 | from typing import Iterable, Iterator, Tuple, Union |
5 | 5 |
|
| 6 | +import pytest |
6 | 7 | from hypothesis import assume, given |
7 | 8 | from hypothesis import strategies as st |
8 | 9 |
|
@@ -168,23 +169,26 @@ def test_expand_dims(x, axis): |
168 | 169 | data=st.data(), |
169 | 170 | ) |
170 | 171 | def test_squeeze(x, data): |
171 | | - # TODO: generate valid negative axis (which keep uniqueness) |
172 | | - squeezable_axes = st.sampled_from( |
173 | | - [i for i, side in enumerate(x.shape) if side == 1] |
174 | | - ) |
| 172 | + axes = st.integers(-x.ndim, x.ndim - 1) |
175 | 173 | axis = data.draw( |
176 | | - squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), |
| 174 | + axes |
| 175 | + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), |
177 | 176 | label="axis", |
178 | 177 | ) |
179 | 178 |
|
| 179 | + axes = (axis,) if isinstance(axis, int) else axis |
| 180 | + axes = normalise_axis(axes, x.ndim) |
| 181 | + |
| 182 | + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] |
| 183 | + if any(i not in squeezable_axes for i in axes): |
| 184 | + with pytest.raises(ValueError): |
| 185 | + xp.squeeze(x, axis) |
| 186 | + return |
| 187 | + |
180 | 188 | out = xp.squeeze(x, axis) |
181 | 189 |
|
182 | 190 | ph.assert_dtype("squeeze", x.dtype, out.dtype) |
183 | 191 |
|
184 | | - if isinstance(axis, int): |
185 | | - axes = (axis,) |
186 | | - else: |
187 | | - axes = axis |
188 | 192 | shape = [] |
189 | 193 | for i, side in enumerate(x.shape): |
190 | 194 | if i not in axes: |
|
0 commit comments