Broadcasting for linalg functions that accept an axis #617
Comments
I don't believe this is correct. The spec, as you quote above, says "where N is the rank (number of dimensions) of the shape determined according to Broadcasting". Meaning, the broadcasted shape, not the pre-broadcasted shape.
@asmeurer I'm having some difficulty following your OP. Can you suggest specifically how you'd like to see the specification revised? If it would be easier, you can submit a PR with the proposed updated guidance.
Nonnegative axes and negative axes below `-N`, where `N` is the rank of the smaller of the two arrays, are unspecified. This is because it is ambiguous in these cases whether the dimension should refer to the axis before or after broadcasting. Previously, the spec stated it should refer to the dimension before broadcasting, but this deviates from NumPy gufunc behavior, and results in ambiguous and confusing situations, where, for instance, the result of the function is different when the inputs are manually broadcasted together. Also clean up some of the `cross` text a little bit since the computed dimension must be exactly size 3. Fixes data-apis#724 Fixes data-apis#617 See the discussion in those issues for more details.
Fix at #740.
This commit updates specification guidance in `vecdot` and `cross` to no longer explicitly support positive `axis` kwarg values. Previous specification guidance conflicts with NumPy gufuncs and restricting to negative integers removes ambiguity in determining over which axis to perform computation. This commit uses `should`, not `must`, to allow conforming libraries to support nonnegative `axis` values for backward compatibility. Closes: #724 Closes: #617 PR-URL: #740 Reviewed-by: Athan Reines <kgryte@gmail.com>
Two linear algebra functions allow contraction over an arbitrary axis, cross and vecdot.
These APIs currently specify:
as well as
This is ambiguous however when the contracted axis is a dimension that is created by broadcasting, for instance
Here axis=0 applied to the first dimension would be size `1`. I think this case should be disallowed.
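To make the ambiguity concrete, here is a small sketch that looks only at shapes. The shapes `(4,)` and `(3, 4)` are an assumption chosen for illustration (they are not the example from the original post), and a `vecdot`-style contraction is assumed so that the contracted axis need not have size 3.

```python
import numpy as np

# Illustrative shapes (an assumption, not the example from the issue).
x1 = np.ones((4,))
x2 = np.ones((3, 4))

# Shape that both inputs are broadcast to.
out_shape = np.broadcast_shapes(x1.shape, x2.shape)

# Broadcasting left-pads the shorter shape with 1s before stretching it.
x1_padded_shape = (1,) * (len(out_shape) - x1.ndim) + x1.shape

axis = 0
print(out_shape)              # (3, 4)
print(x1_padded_shape[axis])  # 1: this dimension of x1 exists only because of broadcasting
print(x2.shape[axis])         # 3
```

With `axis=0` read against the broadcast shape, the contraction would run over a dimension that `x1` never had, which is the case the post suggests disallowing.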
Additionally, `axis` is ambiguous. It isn't clear if it should refer to the axis before or after broadcasting:
This is of particular interest if `axis >= 0`.
NumPy appears to refer to the axis before broadcasting:
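A minimal sketch of that behaviour, with shapes picked here as an assumption rather than copied from the original post. `np.broadcast_shapes` is used only to check whether the full shapes broadcast against each other.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative shapes (an assumption): axis 0 has size 3 in both inputs.
x1 = rng.standard_normal((3, 5))
x2 = rng.standard_normal((3, 4, 5))

# Check whether the full shapes are broadcast compatible.
try:
    np.broadcast_shapes(x1.shape, x2.shape)
    broadcastable = True
except ValueError:
    broadcastable = False

# axis=0 is interpreted against each input's own (pre-broadcast) shape.
result = np.cross(x1, x2, axis=0)

print(broadcastable)  # False
print(result.shape)   # (3, 4, 5)
```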
In fact, these two arrays aren't strictly broadcast compatible. What NumPy does is move the `axis` dimension of `x1` and `x2` to the end of the arrays, then broadcasts `x1[..., 0]` and `x2[..., 0]`. Effectively:
In other words, the arrays should be broadcast compatible after removing `axis` from the shape (and we should have `x1.shape[axis] == x2.shape[axis] == 3`).
NumPy doesn't have `vecdot` yet, but it should obviously work the same (the only difference being the contracted axis can have any size in `vecdot`, not just 3, and unlike `cross`, in `vecdot` the contracted axis is removed from the resulting shape). `torch.linalg.cross` doesn't appear to support any broadcasting.
My implementations of `vecdot` in `numpy.array_api` and `array-api-compat` have been using the idea that `axis` refers to the axis after broadcasting and allowing an added broadcasted axis. But I think this should be changed to work like `np.cross`. The `numpy.array_api` and `array_api_compat.numpy` `cross` implementations just reuse `np.cross` and therefore use those semantics (I didn't realize til now that we weren't actually testing any broadcasting rules for `cross` in the test suite).
This was discussed at data-apis/array-api-compat#35 (comment) (CC @lezcano).
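The move-the-axis-then-broadcast behaviour described above can be sketched as follows. `cross_via_moveaxis` and `vecdot_sketch` are hypothetical helper names introduced for illustration, not functions from NumPy or the array API. The `vecdot` sketch handles real-valued inputs only and is one possible reading of the proposed semantics, not a reference implementation.

```python
import numpy as np

def cross_via_moveaxis(x1, x2, axis):
    """Hypothetical helper: mirror np.cross(x1, x2, axis=axis) by hand."""
    # Move the vector axis of each input (in its own shape) to the end.
    a = np.moveaxis(x1, axis, -1)
    b = np.moveaxis(x2, axis, -1)
    # The remaining leading dimensions broadcast; the size-3 axis is kept
    # and moved back to position `axis` in the result.
    return np.moveaxis(np.cross(a, b), -1, axis)

def vecdot_sketch(x1, x2, axis=-1):
    """Hypothetical helper: vecdot with np.cross-like axis handling (real inputs only)."""
    a = np.moveaxis(x1, axis, -1)
    b = np.moveaxis(x2, axis, -1)
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("contracted axis must have the same size in both arrays")
    # Leading dimensions broadcast; the contracted axis is summed away.
    return np.sum(a * b, axis=-1)

rng = np.random.default_rng(0)
# Illustrative shapes (an assumption): not broadcast compatible as a whole.
x1 = rng.standard_normal((3, 5))
x2 = rng.standard_normal((3, 4, 5))

same = np.allclose(cross_via_moveaxis(x1, x2, 0), np.cross(x1, x2, axis=0))
dots = vecdot_sketch(x1, x2, axis=0)
print(same)        # True
print(dots.shape)  # (4, 5)
```

Unlike `cross`, the contracted axis does not appear in the `vecdot_sketch` result, matching the difference noted in the post.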
Finally, note that tensordot doesn't have this issue because the axes are specified for each array separately.
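For comparison, a short sketch of why `tensordot` avoids the question: each contracted axis is named per input, so no shared `axis` has to be interpreted against a broadcast shape. The shapes are chosen here for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((5, 3))

# Contract axis 0 of `a` with axis 1 of `b`; the axes are given per array.
out = np.tensordot(a, b, axes=([0], [1]))
print(out.shape)  # (4, 5)
```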