@@ -148,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
148
148
in the drawn axis.
149
149
150
150
"""
151
- shape = list ( draw (shapes ()) )
152
- size = len (shape )
153
- assume (size > 0 )
151
+ shape1 , shape2 = draw (two_mutually_broadcastable_shapes )
152
+ min_ndim = min ( len (shape1 ), len ( shape2 ) )
153
+ assume (min_ndim > 0 )
154
154
155
- kw = draw (kwargs (axis = integers (- size , size - 1 )))
155
+ kw = draw (kwargs (axis = integers (- min_ndim , - 1 )))
156
156
axis = kw .get ('axis' , - 1 )
157
- shape [axis ] = 3
158
- shape = tuple (shape )
157
+ if draw (booleans ()):
158
+ # Sometimes allow invalid inputs to test it errors
159
+ shape1 = list (shape1 )
160
+ shape1 [axis ] = 3
161
+ shape1 = tuple (shape1 )
162
+ shape2 = list (shape2 )
163
+ shape2 [axis ] = 3
164
+ shape2 = tuple (shape2 )
159
165
160
166
mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtype_objects ))
161
167
arrays1 = arrays (
162
168
dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
163
- shape = shape ,
169
+ shape = shape1 ,
164
170
)
165
171
arrays2 = arrays (
166
172
dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
167
- shape = shape ,
173
+ shape = shape2 ,
168
174
)
169
175
return draw (arrays1 ), draw (arrays2 ), kw
170
176
@@ -176,15 +182,17 @@ def test_cross(x1_x2_kw):
176
182
x1 , x2 , kw = x1_x2_kw
177
183
178
184
axis = kw .get ('axis' , - 1 )
179
- err = "test_cross produced invalid input. This indicates a bug in the test suite."
180
- assert x1 . shape == x2 . shape , err
181
- shape = x1 . shape
182
- assert x1 . shape [ axis ] == x2 . shape [ axis ] == 3 , err
185
+ if not ( x1 . shape [ axis ] == x2 . shape [ axis ] == 3 ):
186
+ ph . raises ( Exception , lambda : xp . cross ( x1 , x2 , ** kw ),
187
+ "cross did not raise an exception for invalid shapes" )
188
+ return
183
189
184
190
res = linalg .cross (x1 , x2 , ** kw )
185
191
192
+ broadcasted_shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
193
+
186
194
assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "cross() did not return the correct dtype"
187
- assert res .shape == shape , "cross() did not return the correct shape"
195
+ assert res .shape == broadcasted_shape , "cross() did not return the correct shape"
188
196
189
197
def exact_cross (a , b ):
190
198
assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
0 commit comments