Skip to content

Commit 8f4327c

Browse files
authored
Add complex number support to tensordot (#558)
1 parent 031987d commit 8f4327c

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

spec/API_specification/array_api/linear_algebra_functions.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,15 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
6060
"""
6161
Returns a tensor contraction of ``x1`` and ``x2`` over specific axes.
6262
63+
.. note::
64+
The ``tensordot`` function corresponds to the generalized matrix product.
65+
6366
Parameters
6467
----------
6568
x1: array
66-
first input array. Should have a real-valued data type.
69+
first input array. Should have a numeric data type.
6770
x2: array
68-
second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
71+
second input array. Should have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
6972
7073
.. note::
7174
Contracted axes (dimensions) must not be broadcasted.
@@ -81,6 +84,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
8184
8285
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each sequence must consist of unique (nonnegative) integers that specify valid axes for each respective array.
8386
87+
88+
.. note::
89+
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the generalized matrix product.
90+
8491
Returns
8592
-------
8693
out: array

0 commit comments

Comments
 (0)