Implement Dot and BatchedDot in PyTensor #878
Conversation
Well, that was a roundabout trip :D Thanks!
@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
You have to import this file from pytorch.dispatch.__init__ for it to be registered (the test is failing in the CI). But Dot is not defined in nlinalg, so we should put it in dispatch/math.py? Same for the test.
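For context, a minimal sketch of the registration being discussed, assuming the PyTorch backend mirrors the JAX dispatch layout; the module path, the pytorch_funcify location, and the torch.matmul body are illustrative assumptions rather than the PR's actual code:

# Illustrative sketch of pytensor/link/pytorch/dispatch/math.py. The module
# also needs to be imported from pytensor/link/pytorch/dispatch/__init__.py
# (e.g. `import pytensor.link.pytorch.dispatch.math  # noqa: F401`), since it
# is that import that executes the @pytorch_funcify.register calls.
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.math import Dot


@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
    def dot(x, y):
        # torch.matmul covers the vector/matrix cases that Dot handles
        return torch.matmul(x, y)

    return dot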
I based it off the JAX link. If you take a look at pytensor/link/jax/dispatch/nlinalg.py you will see the Max, Argmax, and Dot Ops from math in there. Do you want me to separate them out for JAX too?
I can also put the Argmax I am implementing in pytorch/dispatch/math.py.
Yeah, in general we want to keep it more or less mirrored with the file structure where the Ops are defined, although our tensor/basic.py and tensor/math.py are in need of being split up as they have way too many lines.
Yep,
@ricardoV94 BatchedDot is done. I will do Max and Argmax next. They are tougher nuts to crack.
Isn't Max done already? Should be like
pytensor/link/jax/dispatch/blas.py
Outdated
def jax_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match in the 0-th dimension")
Suggested change:
-            raise TypeError("Shapes must match in the 0-th dimension")
+            raise TypeError("Shapes must match along the first dimension of BatchedDot")
Can you split the JAX changes into a separate PR? It's better to have PRs atomic, as that makes them easier to review. Sometimes it's okay to have multiple pieces of functionality in a PR, but then you have to respect this part of the checklist:
pytensor/link/jax/dispatch/math.py
Outdated
@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
We should have the dispatch for the other CAReduce Ops, like All, Any... in the same file as Max.
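A hedged sketch of keeping the CAReduce-style reductions together in one dispatch module; reading op.axis and mapping to the jnp reducers are assumptions about these Ops, not code taken from the PR:

import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.math import All, Any, Max


@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
    axis = op.axis

    def max_fn(x):
        return jnp.max(x, axis=axis)

    return max_fn


@jax_funcify.register(All)
def jax_funcify_All(op, **kwargs):
    axis = op.axis

    def all_fn(x):
        return jnp.all(x, axis=axis)

    return all_fn


@jax_funcify.register(Any)
def jax_funcify_Any(op, **kwargs):
    axis = op.axis

    def any_fn(x):
        return jnp.any(x, axis=axis)

    return any_fn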
Okay, I moved it to #913
@ricardoV94 For this one, I will stop at
tests/link/pytorch/test_blas.py
Outdated
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
pytorch_mode = Mode(PytorchLinker(), opts)
pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode)
This does the same?

Suggested change:
-opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-pytorch_mode = Mode(PytorchLinker(), opts)
-pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode)
+pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
+pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode_no_rewrites)
But if I am not mistaken, compare_pytorch_and_py returns the torch function, so you could just reuse it?
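A sketch of what reusing the helper's return value could look like; that compare_pytorch_and_py lives in tests/link/pytorch/test_basic.py and returns the compiled function as its first value is an assumption based on this thread, as is the batched_dot graph construction:

import numpy as np
import pytest

from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor3
from pytensor.tensor.blas import batched_dot
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_BatchedDot():
    a = tensor3("a")
    b = tensor3("b")
    a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))

    out = batched_dot(a, b)
    fgraph = FunctionGraph([a, b], [out])

    # The helper compiles the graph with the PyTorch linker and checks it
    # against the Python backend; assuming it returns (compiled_fn, result),
    # the compiled function can be reused instead of compiling a second one.
    pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test])

    # A mismatch along the batch dimension should raise
    with pytest.raises(TypeError):
        pytensor_pytorch_fn(a_test[:-1], b_test)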
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff            @@
##              main     #878    +/-   ##
=========================================
  Coverage    81.40%   81.40%
=========================================
  Files          173      175     +2
  Lines        46914    46934    +20
  Branches     11426    11427     +1
=========================================
+ Hits         38188    38205    +17
- Misses        6544     6547     +3
  Partials      2182     2182
tests/link/pytorch/test_blas.py
Outdated
pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites
This is not a thing you can do (or rather, it has no effect). Once a function is compiled, that's it; the mode plays no role anymore.
Suggested change (remove these lines):
-pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
-pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites
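A minimal sketch of the point being made, using a toy graph (the PytorchLinker import path is an assumption): the linker and rewrites are fixed when function() compiles, so assigning .mode afterwards changes nothing.

import numpy as np

from pytensor import config, function
from pytensor.compile.mode import Mode
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.tensor import matrix

x = matrix("x")
y = matrix("y")
out = x @ y

# Effective: the linker and (absence of) rewrites are chosen here, at compile time
pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
fn = function([x, y], [out], mode=pytorch_mode_no_rewrites)

# No effect: the function is already compiled, so this attribute is simply ignored
fn.mode = Mode(PytorchLinker(), None)

fn(np.ones((2, 3), dtype=config.floatX), np.ones((3, 4), dtype=config.floatX))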
tests/link/pytorch/test_blas.py
Outdated
a.tag.test_value = (
    np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
We are getting rid of the test_value machinery. Just pass these directly to the test function; there's no point in putting them in the tag only to retrieve them again.
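A short before/after sketch of that suggestion; the variable names and the commented helper call are illustrative:

import numpy as np

from pytensor import config
from pytensor.tensor import tensor3

a = tensor3("a")

# Before: stash the value on the variable's tag, only to fetch it again later
a.tag.test_value = (
    np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)

# After: keep it as a plain local and hand it to the test helper directly,
# e.g. compare_pytorch_and_py(fgraph, [a_test])
a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))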
Looks good, just a nit if you want to address it.
tests/link/pytorch/test_blas.py
Outdated
def test_pytorch_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    A = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
Nit: more conventional names would be a_test and b_test for variables a and b.
GitHub says the branch has conflicts with main. Can you update?
Done. |
Description
Implemented the PyTorch link and unit tests for the math Op Dot and the blas Op BatchedDot. Did not touch on BatchedDot or Dot for sparse matrices.
Progress
Related Issue
Checklist
Type of change