Skip to content

Commit 7728c98

Browse files
committed
Only require axis to be negative in vecdot and cross
Nonnegative axes and negative axes less than the smaller of the two arrays are unspecified. This is because it is ambiguous in these cases whether the dimension should refer to the axis before or after broadcasting. Preciously, the spec stated it should refer to the dimension before broadcasting, but this deviates from NumPy gufunc behavior, and results in ambiguous and confusing situations, where, for instance, the result of a the function is different when the inputs are manually broadcasted together. Also clean up some of the cross text a little bit since the computed dimension must be exactly size 3. Fixes data-apis#724 Fixes data-apis#617 See the discussion in those issues for more details.
1 parent 95332bb commit 7728c98

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/array_api_stubs/_draft/linalg.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
8383
Parameters
8484
----------
8585
x1: array
86-
first input array. Must have a numeric data type.
86+
first input array. Must have a numeric data type. The size of the axis over which the cross product is to be computed must be equal to 3.
8787
x2: array
88-
second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Must have a numeric data type.
88+
second input array. Must be broadcast compatible with ``x1`` along all axes other than the axis along which the cross-product is computed (see :ref:`broadcasting`). The size of the axis over which the cross product is to be computed must be equal to 3. Must have a numeric data type.
8989
9090
.. note::
9191
The compute axis (dimension) must not be broadcasted.
9292
9393
axis: int
94-
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
94+
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``.
9595
9696
Returns
9797
-------
@@ -110,8 +110,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
110110
111111
**Raises**
112112
113-
- if the size of the axis over which to compute the cross product is not equal to ``3``.
114-
- if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``.
113+
- if the size of the axis over which to compute the cross product is not equal to ``3`` (before broadcasting) for both ``x1`` and ``x2``.
115114
"""
116115

117116

src/array_api_stubs/_draft/linear_algebra_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
141141
The contracted axis (dimension) must not be broadcasted.
142142
143143
axis: int
144-
axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
144+
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
145145
146146
Returns
147147
-------

0 commit comments

Comments
 (0)