Skip to content

Commit 22a248e

Browse files
committed
Update test_cross to test broadcastable shapes
This also updates it to only test axes from [min(x1.ndim, x2.ndim), -1], as per data-apis/array-api#740
1 parent 4403061 commit 22a248e

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

array_api_tests/test_linalg.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
148148
in the drawn axis.
149149
150150
"""
151-
shape = list(draw(shapes()))
152-
size = len(shape)
153-
assume(size > 0)
151+
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
152+
min_ndim = min(len(shape1), len(shape2))
153+
assume(min_ndim > 0)
154154

155-
kw = draw(kwargs(axis=integers(-size, size-1)))
155+
kw = draw(kwargs(axis=integers(-min_ndim, -1)))
156156
axis = kw.get('axis', -1)
157-
shape[axis] = 3
158-
shape = tuple(shape)
157+
if draw(booleans()):
158+
# Sometimes allow invalid inputs to test it errors
159+
shape1 = list(shape1)
160+
shape1[axis] = 3
161+
shape1 = tuple(shape1)
162+
shape2 = list(shape2)
163+
shape2[axis] = 3
164+
shape2 = tuple(shape2)
159165

160166
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
161167
arrays1 = arrays(
162168
dtype=mutual_dtypes.map(lambda pair: pair[0]),
163-
shape=shape,
169+
shape=shape1,
164170
)
165171
arrays2 = arrays(
166172
dtype=mutual_dtypes.map(lambda pair: pair[1]),
167-
shape=shape,
173+
shape=shape2,
168174
)
169175
return draw(arrays1), draw(arrays2), kw
170176

@@ -176,15 +182,17 @@ def test_cross(x1_x2_kw):
176182
x1, x2, kw = x1_x2_kw
177183

178184
axis = kw.get('axis', -1)
179-
err = "test_cross produced invalid input. This indicates a bug in the test suite."
180-
assert x1.shape == x2.shape, err
181-
shape = x1.shape
182-
assert x1.shape[axis] == x2.shape[axis] == 3, err
185+
if not (x1.shape[axis] == x2.shape[axis] == 3):
186+
ph.raises(Exception, lambda: xp.cross(x1, x2, **kw),
187+
"cross did not raise an exception for invalid shapes")
188+
return
183189

184190
res = linalg.cross(x1, x2, **kw)
185191

192+
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
193+
186194
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
187-
assert res.shape == shape, "cross() did not return the correct shape"
195+
assert res.shape == broadcasted_shape, "cross() did not return the correct shape"
188196

189197
def exact_cross(a, b):
190198
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."

0 commit comments

Comments
 (0)