@@ -22,6 +22,34 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return torch_linalg.cross(x1, x2, dim=axis)
 
-__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
+def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+    from ._aliases import isdtype
+
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+
+    # torch.linalg.vecdot doesn't support integer dtypes
+    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
+        if kwargs:
+            raise RuntimeError("vecdot kwargs not supported for integral dtypes")
+        ndim = max(x1.ndim, x2.ndim)
+        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
+        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
+        if x1_shape[axis] != x2_shape[axis]:
+            raise ValueError("x1 and x2 must have the same size along the given axis")
+
+        x1_, x2_ = torch.broadcast_tensors(x1, x2)
+        x1_ = torch.moveaxis(x1_, axis, -1)
+        x2_ = torch.moveaxis(x2_, axis, -1)
+
+        res = x1_[..., None, :] @ x2_[..., None]
+        return res[..., 0, 0]
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
+
+def solve(x1: array, x2: array, /, **kwargs) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.linalg.solve(x1, x2, **kwargs)
+
+__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
+                        'vecdot', 'solve']
 
 del linalg_all
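
For reference, a minimal sketch of how the new `vecdot` behaves from the caller's side, assuming the wrappers are importable as `array_api_compat.torch.linalg` (the import path used elsewhere in this repo). Integer inputs take the broadcast-and-matmul fallback added above; floating-point inputs are forwarded straight to `torch.linalg.vecdot`:

import torch
from array_api_compat.torch import linalg  # assumed import path

# Integer dtypes hit the fallback: broadcast, move `axis` last, then
# contract with a batched matmul, since torch.linalg.vecdot rejects them.
a = torch.tensor([[1, 2, 3], [4, 5, 6]])   # shape (2, 3), int64
b = torch.tensor([1, 0, 1])                # shape (3,), broadcasts to (2, 3)
print(linalg.vecdot(a, b, axis=-1))        # tensor([ 4, 10])

# Floating-point dtypes are passed through to torch.linalg.vecdot directly.
print(linalg.vecdot(a.double(), b.double(), axis=-1))  # tensor([ 4., 10.])

The matmul trick in the fallback (`x1_[..., None, :] @ x2_[..., None]`) computes each dot product as a 1×n by n×1 product, so integer results stay exact instead of being cast to float.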