-
Notifications
You must be signed in to change notification settings - Fork 33
More fixes for torch linalg extension #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
968e499
Add implementation for torch.linalg.vecdot for integer dtypes
asmeurer dfc1275
Pass kwargs through to torch.linalg.vecdot
asmeurer d6bf189
Add wrapper for torch.linalg.solve that does correct type promotion
asmeurer e6aa969
The torch log1p test can apparently fail on CI too
asmeurer 32fc8e7
Add missing numpy 1.21 xfail
asmeurer 05204b3
Fix the naming of the numpy jobs on CI
asmeurer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These semantics allow for a weird construction like
vecdot(xp.randn(1,4,5,6), xp.randn(6), axis=0)
which would be equivalent toxp.randn(1,4,5,6) * xp.randn(6)
. This may be a discussion for the general broadcasting rules for the API, but perhaps you want that to assert thatall(x.ndim - 1 >= axis for x in input_tensors)
(perhaps with some special treatment for 0-dim tensors).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that I think about it, I wonder if this sort of thing should already be disallowed by the spec. https://data-apis.org/array-api/latest/API_specification/generated/array_api.vecdot.html#array_api.vecdot
I've been working on the assumption that axis refers to the dimension after broadcasting ("Must be an integer on the interval [-N, N), where N is the rank (number of dimensions) of the shape determined according to Broadcasting."). But it also says "The contracted axis (dimension) must not be broadcasted."
I had been interpreting that as meaning you shouldn't allow something like
vecdot(empty((3, 3)), empty((1, 3)), axis=0)
. But I suppose it could also be taken to mean that broadcasting shouldn't "broadcast up" to the contracted dimension either. Something more along the lines of(e.g.,
vecdot(empty((1, 2, 3, 4, 5)), empty((3, 4, 5)), axis=2)
is fine butvecdot(empty((1, 2, 3, 4, 5)), empty((3, 4, 5)), axis=0)
is not)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think disallowing both (not broadcast along the reduced dimension and ask for the axis to be well defined in the sense you just described) would be the safer thing to do, as otherwise you end up with these funny constructions, which is not great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an issue in all the other implementations too, and arguably the spec as well. I'm going to deal with it in a separate PR.