Skip to content

Commit 4403061

Browse files
committed
Fix test_vecdot to only generate axis in [-min(x1.ndim, x2.ndim), -1]
See data-apis/array-api#740
1 parent 6c4d455 commit 4403061

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

Diff for: array_api_tests/test_linalg.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -769,17 +769,15 @@ def true_trace(x_stack, offset=0):
769769

770770
@given(
771771
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
772-
kwargs(axis=integers()),
772+
data(),
773773
)
774-
def test_vecdot(x1, x2, kw):
774+
def test_vecdot(x1, x2, data):
775775
# TODO: vary shapes, test different axis arguments
776776
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
777+
min_ndim = min(x1.ndim, x2.ndim)
777778
ndim = len(broadcasted_shape)
779+
kw = data.draw(kwargs(axis=integers(-min_ndim, -1)))
778780
axis = kw.get('axis', -1)
779-
if not (-ndim <= axis < ndim):
780-
ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
781-
f"vecdot did not raise an exception for invalid axis ({ndim=}, {kw=})")
782-
return
783781
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
784782
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
785783
if x1_shape[axis] != x2_shape[axis]:

0 commit comments

Comments
 (0)