|
| 1 | +import math |
| 2 | + |
1 | 3 | from hypothesis import given |
2 | 4 | from hypothesis import strategies as st |
3 | 5 |
|
|
11 | 13 | @given( |
12 | 14 | shape=hh.shapes(min_dims=1), |
13 | 15 | dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), |
| 16 | + kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 |
14 | 17 | data=st.data(), |
15 | 18 | ) |
16 | | -def test_concat(shape, dtypes, data): |
| 19 | +def test_concat(shape, dtypes, kw, data): |
17 | 20 | arrays = [] |
18 | 21 | for i, dtype in enumerate(dtypes, 1): |
19 | 22 | x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") |
20 | 23 | arrays.append(x) |
21 | | - out = xp.concat(arrays) |
| 24 | + out = xp.concat(arrays, **kw) |
22 | 25 | ph.assert_dtype("concat", dtypes, out.dtype) |
23 | | - # TODO |
| 26 | + shapes = tuple(x.shape for x in arrays) |
| 27 | + if kw.get("axis", 0) == 0: |
| 28 | + pass # TODO: assert expected shape |
| 29 | + elif kw["axis"] is None: |
| 30 | + size = sum(math.prod(s) for s in shapes) |
| 31 | + ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) |
| 32 | + # TODO: assert out elements match input arrays |
24 | 33 |
|
25 | 34 |
|
26 | 35 | @given( |
|
0 commit comments