We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6c4d455 commit 4403061Copy full SHA for 4403061
array_api_tests/test_linalg.py
@@ -769,17 +769,15 @@ def true_trace(x_stack, offset=0):
769
770
@given(
771
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
772
- kwargs(axis=integers()),
+ data(),
773
)
774
-def test_vecdot(x1, x2, kw):
+def test_vecdot(x1, x2, data):
775
# TODO: vary shapes, test different axis arguments
776
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
777
+ min_ndim = min(x1.ndim, x2.ndim)
778
ndim = len(broadcasted_shape)
779
+ kw = data.draw(kwargs(axis=integers(-min_ndim, -1)))
780
axis = kw.get('axis', -1)
- if not (-ndim <= axis < ndim):
- ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
781
- f"vecdot did not raise an exception for invalid axis ({ndim=}, {kw=})")
782
- return
783
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
784
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
785
if x1_shape[axis] != x2_shape[axis]:
0 commit comments