99from . import pytest_helpers as ph
1010from . import xps
1111
12+ shared_shapes = st .shared (hh .shapes (), key = "shape" )
13+
1214
1315@given (
1416 shape = hh .shapes (min_dims = 1 ),
@@ -32,6 +34,81 @@ def test_concat(shape, dtypes, kw, data):
3234 # TODO: assert out elements match input arrays
3335
3436
37+ @given (
38+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
39+ axis = shared_shapes .flatmap (lambda s : st .integers (- len (s ), len (s ))),
40+ )
41+ def test_expand_dims (x , axis ):
42+ xp .expand_dims (x , axis = axis )
43+ # TODO
44+
45+
46+ @given (
47+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
48+ kw = hh .kwargs (
49+ axis = st .one_of (
50+ st .none (),
51+ shared_shapes .flatmap (
52+ lambda s : st .none ()
53+ if len (s ) == 0
54+ else st .integers (- len (s ) + 1 , len (s ) - 1 ),
55+ ),
56+ )
57+ ),
58+ )
59+ def test_flip (x , kw ):
60+ xp .flip (x , ** kw )
61+ # TODO
62+
63+
64+ @given (
65+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
66+ axes = shared_shapes .flatmap (
67+ lambda s : st .lists (
68+ st .integers (0 , max (len (s ) - 1 , 0 )),
69+ min_size = len (s ),
70+ max_size = len (s ),
71+ unique = True ,
72+ ).map (tuple )
73+ ),
74+ )
75+ def test_permute_dims (x , axes ):
76+ xp .permute_dims (x , axes )
77+ # TODO
78+
79+
80+ @given (
81+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
82+ shape = shared_shapes , # TODO: test more compatible shapes
83+ )
84+ def test_reshape (x , shape ):
85+ xp .reshape (x , shape )
86+ # TODO
87+
88+
89+ @given (
90+ # TODO: axis arguments, update shift respectively
91+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
92+ shift = shared_shapes .flatmap (lambda s : st .integers (0 , max (math .prod (s ) - 1 , 0 ))),
93+ )
94+ def test_roll (x , shift ):
95+ xp .roll (x , shift )
96+ # TODO
97+
98+
99+ @given (
100+ x = xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ),
101+ axis = shared_shapes .flatmap (
102+ lambda s : st .just (0 )
103+ if len (s ) == 0
104+ else st .integers (- len (s ) + 1 , len (s ) - 1 ).filter (lambda i : s [i ] == 1 )
105+ ), # TODO: tuple of axis i.e. axes
106+ )
107+ def test_squeeze (x , axis ):
108+ xp .squeeze (x , axis )
109+ # TODO
110+
111+
35112@given (
36113 shape = hh .shapes (),
37114 dtypes = hh .mutually_promotable_dtypes (None ),
0 commit comments