Commit e16b937 · committed Mar 10, 2022 · 1 parent d0f9556
[Array API] Add linalg.vecdot
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement \sum_i \overline{x_i} y_i, i.e., the first argument is conjugated, consistently with `torch.vdot`. See the discussion in data-apis/array-api#356.

ghstack-source-id: 9ebc81abaf7a31b10ff3552f661d0c1944b86aa7
Pull Request resolved: #70542
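For a quick sanity check of the chosen semantics, a minimal sketch (assuming a PyTorch build that includes this commit, so that `torch.linalg.vecdot` is available):

```python
import torch

x = torch.randn(3, dtype=torch.cfloat)
y = torch.randn(3, dtype=torch.cfloat)

# vecdot conjugates its first argument, matching torch.vdot and the
# reference (x.conj() * y).sum(dim) used by the new OpInfo entry below.
torch.testing.assert_close(torch.linalg.vecdot(x, y), (x.conj() * y).sum(-1))
torch.testing.assert_close(torch.linalg.vecdot(x, y), torch.vdot(x, y))
```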

7 files changed, 104 insertions(+), 7 deletions(-)

aten/src/ATen/native/BatchLinearAlgebra.cpp (+23)

```diff
@@ -3835,6 +3835,29 @@ TransposeType to_transpose_type(const bool contig, const bool conj) {
 }
 } // end of anonymous namespace
 
+Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
+  at::native::checkFloatingOrComplex(x, "linalg.vecdot");
+  at::native::checkFloatingOrComplex(y, "linalg.vecdot");
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    at::native::resize_output(out, {});
+    return at::vdot_out(out, x, y);
+  } else {
+    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
+  }
+}
+
+Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
+  at::native::checkFloatingOrComplex(x, "linalg.vecdot");
+  at::native::checkFloatingOrComplex(y, "linalg.vecdot");
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    return at::vdot(x, y);
+  } else {
+    return x.conj().mul(y).sum(/*dim=*/dim);
+  }
+}
+
 /*
 Solves the matrix equation AX = B for A triangular.
 'left' If true solves AX = B, if false solves XA = B
```
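In Python terms, the dispatch above behaves roughly like the following sketch (`vecdot_reference` is a hypothetical name used only for illustration):

```python
import torch

def vecdot_reference(x, y, dim=-1):
    # Two plain 1-D vectors take the fast path through torch.vdot,
    # which conjugates its first argument; batched inputs fall back
    # to an explicit conjugate-multiply-sum along `dim`.
    if x.dim() == 1 and y.dim() == 1:
        return torch.vdot(x, y)
    return (x.conj() * y).sum(dim=dim)

x = torch.randn(5, dtype=torch.cfloat)
y = torch.randn(5, dtype=torch.cfloat)
torch.testing.assert_close(vecdot_reference(x, y), torch.vdot(x, y))
```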

aten/src/ATen/native/native_functions.yaml (+7)

```diff
@@ -10865,6 +10865,13 @@
 - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
 
+- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
 - func: linalg_matrix_exp(Tensor self) -> Tensor
   python_module: linalg
   variants: function
```
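Per this schema, `dim` is keyword-only and both overloads live in the `linalg` Python module. A small usage sketch of the two variants:

```python
import torch

x = torch.randn(3, 4)
y = torch.randn(3, 4)

z = torch.linalg.vecdot(x, y, dim=-1)        # functional variant
out = torch.empty(3)
torch.linalg.vecdot(x, y, dim=-1, out=out)   # routes through linalg_vecdot.out
torch.testing.assert_close(z, out)
```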

docs/source/linalg.rst (+1)

```diff
@@ -82,6 +82,7 @@ Matrix Products
 
     cross
     matmul
+    vecdot
    multi_dot
     householder_product
```

torch/_torch_docs.py (+18 −5)

```diff
@@ -3446,21 +3446,34 @@ def merge_dicts(*dicts):
            r"""
 vdot(input, other, *, out=None) -> Tensor
 
-Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
-differently than dot(a, b). If the first argument is complex, the complex conjugate of the
-first argument is used for the calculation of the dot product.
+Computes the dot product of two 1D vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
 
 .. note::
 
     Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
     of two 1D tensors with the same number of elements.
 
+.. seealso::
+
+    :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
+
 Args:
     input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
     other (Tensor): second tensor in the dot product, must be 1D.
 
 Keyword args:
-    {out}
+""" + fr"""
+.. note:: {common_args["out"]}
+""" + r"""
 
 Example::
 
@@ -3472,7 +3485,7 @@ def merge_dicts(*dicts):
     tensor([16.+1.j])
     >>> torch.vdot(b, a)
     tensor([16.-1.j])
-""".format(**common_args))
+""")
 
 add_docstr(torch.eig,
            r"""
```

torch/linalg/__init__.py (+35)

```diff
@@ -2378,3 +2378,38 @@
     >>> torch.dist(Q.mT @ Q, torch.eye(4))
     tensor(6.2158e-07)
 """)
+
+vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
+linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor
+
+Computes the dot product of two batches of vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
+
+Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
+It also supports broadcasting.
+
+Args:
+    x (Tensor): first batch of vectors of shape `(*, n)`.
+    y (Tensor): second batch of vectors of shape `(*, n)`.
+
+Keyword args:
+    dim (int): Dimension along which to compute the dot product. Default: `-1`.
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Examples::
+
+    >>> v1 = torch.randn(3, 2)
+    >>> v2 = torch.randn(3, 2)
+    >>> linalg.vecdot(v1, v2)
+    tensor([ 0.3223,  0.2815, -0.1944])
+    >>> torch.vdot(v1[0], v2[0])
+    tensor(0.3223)
+""")
```

torch/overrides.py (+1)

```diff
@@ -874,6 +874,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.ravel: lambda input: -1,
     torch.real: lambda input, out=None: -1,
     torch.vdot: lambda input, other, out=None: -1,
+    torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
     torch.view_as_real: lambda input: -1,
     torch.view_as_complex: lambda input: -1,
     torch.reciprocal: lambda input, out=None: -1,
```

torch/testing/_internal/common_methods_invocations.py (+19 −2)

```diff
@@ -991,7 +991,7 @@ def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
     supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True)
 
     # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
-    # use op_info.genearte_args_kwargs directly.
+    # use op_info.generate_args_kwargs directly.
     generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {}))
 
     inputs: List[SampleInput] = []
@@ -1170,7 +1170,7 @@ class ReductionOpInfo(OpInfo):
     the optional keyword parameters of the ReductionOpInfo constructor.
 
     If a reduction operator does not yet implement the full required API of
-    reduction operators, this should be documented by skipping the failing
+    reduction operators, this should be documented by xfailing the failing
     tests rather than adding optional parameters to ReductionOpInfo.
 
     NOTE
@@ -2120,6 +2120,16 @@ def sample_inputs_isclose(
         yield SampleInput(lhs, args=(rhs,),
                           kwargs=dict(op_kwargs, rtol=rtol, atol=atol, equal_nan=equal_nan))
 
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batches = ((), (0,), (1,), (5,))
+    ns = (0, 1, 3, 5)
+    for b, n in product(batches, ns):
+        shape = b + (n,)
+        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
+        for i in range(len(shape)):
+            yield SampleInput(make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i))
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
```
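The generator above pairs every batch shape with every vector length, with and without an explicit `dim`. Enumerating the base shapes it produces:

```python
from itertools import product

batches = ((), (0,), (1,), (5,))
ns = (0, 1, 3, 5)
shapes = [b + (n,) for b, n in product(batches, ns)]
print(len(shapes), shapes)  # 16 shapes, from (0,) up to (5, 5), including zero-sized cases
```

The final hunk in this file registers the OpInfo entry for the new op against a conjugate-multiply-sum reference: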
```diff
@@ -9898,6 +9908,13 @@ def ref_pairwise_distance(input1, input2):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
            ),
+    OpInfo('linalg.vecdot',
+           aten_name='linalg_vecdot',
+           ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
+           sample_inputs_func=sample_inputs_linalg_vecdot,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
```
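The `ref` lambda is the ground truth the OpInfo tests compare against; the relationship it asserts, written out directly:

```python
import torch

x = torch.randn(2, 5, dtype=torch.cfloat)
y = torch.randn(2, 5, dtype=torch.cfloat)

# Same reference as the OpInfo entry: (x.conj() * y).sum(dim)
torch.testing.assert_close(torch.linalg.vecdot(x, y, dim=-1),
                           (x.conj() * y).sum(-1))
```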
