|
1 | 1 | import math |
2 | 2 | from collections import deque |
3 | 3 | from itertools import product |
4 | | -from typing import Iterable, Iterator, Tuple, Union |
| 4 | +from typing import Iterable, Union |
5 | 5 |
|
6 | 6 | from hypothesis import assume, given |
7 | 7 | from hypothesis import strategies as st |
@@ -45,15 +45,6 @@ def assert_array_ndindex( |
45 | 45 | assert out[out_idx] == x[x_idx], msg |
46 | 46 |
|
47 | 47 |
|
48 | | -def axis_ndindex( |
49 | | - shape: Shape, axis: int |
50 | | -) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: |
51 | | - iterables = [range(side) for side in shape[:axis]] |
52 | | - for _ in range(len(shape[axis:])): |
53 | | - iterables.append([slice(None, None)]) |
54 | | - yield from product(*iterables) |
55 | | - |
56 | | - |
57 | 48 | def assert_equals( |
58 | 49 | func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw |
59 | 50 | ): |
@@ -124,7 +115,10 @@ def test_concat(dtypes, kw, data): |
124 | 115 | ) |
125 | 116 | else: |
126 | 117 | out_indices = ah.ndindex(out.shape) |
127 | | - for idx in axis_ndindex(shape, axis): |
| 118 | + axis_indices = [range(side) for side in shapes[0][:_axis]] |
| 119 | + for _ in range(_axis, len(shape)): |
| 120 | + axis_indices.append([slice(None, None)]) |
| 121 | + for idx in product(*axis_indices): |
128 | 122 | f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) |
129 | 123 | for x_num, x in enumerate(arrays, 1): |
130 | 124 | indexed_x = x[idx] |
|
0 commit comments