Skip to content

Commit eba54b3

Browse files
authored
Add support for broadcasting to linalg.cross (#417)
* Make explicit that broadcasting only applies to non-compute dimensions in vecdot * Add support for broadcasting to `linalg.cross` * Update copy * Update copy * Fix spacing
1 parent 55b8fb0 commit eba54b3

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

spec/API_specification/array_api/linalg.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,34 @@ def cholesky(x: array, /, *, upper: bool = False) -> array:
4141

4242
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
4343
"""
44-
Returns the cross product of 3-element vectors. If ``x1`` and ``x2`` are multi-dimensional arrays (i.e., both have a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed.
44+
Returns the cross product of 3-element vectors.
45+
46+
If ``x1`` and/or ``x2`` are multi-dimensional arrays (i.e., the broadcasted result has a rank greater than ``1``), then the cross-product of each pair of corresponding 3-element vectors is independently computed.
4547
4648
Parameters
4749
----------
4850
x1: array
4951
first input array. Should have a real-valued data type.
5052
x2: array
51-
second input array. Must have the same shape as ``x1``. Should have a real-valued data type.
53+
second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Should have a real-valued data type.
54+
55+
.. note::
56+
The compute axis (dimension) must not be broadcasted.
57+
5258
axis: int
53-
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. If set to ``-1``, the function computes the cross product for vectors defined by the last axis (dimension). Default: ``-1``.
59+
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
5460
5561
Returns
5662
-------
5763
out: array
5864
an array containing the cross products. The returned array must have a data type determined by :ref:`type-promotion`.
65+
66+
67+
**Raises**
68+
69+
- if provided an invalid ``axis``.
70+
- if the size of the axis over which to compute the cross product is not equal to ``3``.
71+
- if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``.
5972
"""
6073

6174
def det(x: array, /) -> array:

spec/API_specification/array_api/linear_algebra_functions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
9292
x1: array
9393
first input array. Should have a real-valued data type.
9494
x2: array
95-
second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
95+
second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a real-valued data type.
9696
9797
.. note::
98-
Contracted axes (dimensions) must not be broadcasted.
98+
The contracted axis (dimension) must not be broadcasted.
9999
100-
axis:int
100+
axis: int
101101
axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
102102
103103
Returns
@@ -109,7 +109,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
109109
**Raises**
110110
111111
- if provided an invalid ``axis``.
112-
- if the size of the axis over which to compute the dot product is not the same for both ``x1`` and ``x2``.
112+
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
113113
"""
114114

115115
__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']

0 commit comments

Comments
 (0)