
Implement Dot and BatchedDot in PyTensor #878

Merged
merged 13 commits into pymc-devs:main from HangenYuu:torch_dot on Jul 18, 2024

Conversation

@HangenYuu HangenYuu (Contributor) commented Jul 3, 2024

Description

Implemented the PyTorch link and unit tests for the math Op Dot and the blas Op BatchedDot. Did not touch BatchedDot or Dot for sparse matrices.

Progress

  • Dot
  • BatchedDot

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels on Jul 4, 2024
@ricardoV94 ricardoV94 (Member) left a comment:

Well that was a roundabout trip :D Thanks!



@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
@ricardoV94 ricardoV94 (Member) commented on Jul 8, 2024:

You have to import this file from pytorch.dispatch.__init__ for it to be registered (the test is failing in the CI). But Dot is not defined in nlinalg, so we should put it in dispatch/math.py? Same for the test.
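For reference, a rough sketch of what that registration import could look like in pytensor/link/pytorch/dispatch/__init__.py (the exact set of imports in the real file is an assumption here):

# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify

# Importing the submodules executes their @pytorch_funcify.register(...) calls,
# which is what actually registers the Ops with the dispatcher.
import pytensor.link.pytorch.dispatch.math  # Dot
import pytensor.link.pytorch.dispatch.blas  # BatchedDot
# isort: on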

HangenYuu (Contributor, PR author) replied:

I based it off the JAX link. If you take a look at pytensor/link/jax/dispatch/nlinalg.py you will see Max, Argmax, and Dot Ops from math in there. Do you want me to separate them out for JAX too?

HangenYuu (Contributor, PR author) added:

I can also put the Argmax I am implementing in pytorch/dispatch/math.py.

ricardoV94 (Member) replied:

Yeah, in general we want to keep it more or less mirrored with the file structure where the Ops are defined. Although our tensor/basic.py and tensor/math.py are in need of being split up, as they have way too many lines.

@ricardoV94 ricardoV94 (Member) commented Jul 8, 2024

BatchedDot should be pretty simple as well: it's a matmul of 3D tensors without allowing broadcasting on the leading dimension, so with a check that a.shape[0] == b.shape[0]. Wanna give it a go? You should be able to trigger it if you test tensor(shape=(5, 3, 2)) @ tensor(shape=(5, 2, 4)) (pseudo-code).
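A minimal sketch of such a dispatch, following the pytorch_funcify registration pattern shown above (illustrative only, not necessarily the code that was merged):

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.blas import BatchedDot


@pytorch_funcify.register(BatchedDot)
def pytorch_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        # torch.matmul would broadcast the leading dimension, so enforce
        # the BatchedDot contract explicitly.
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match along the first dimension of BatchedDot")
        return torch.matmul(a, b)

    return batched_dot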

@HangenYuu HangenYuu changed the title from "Added PyTorch link and unit tests for normal dot" to "Added PyTorch link and unit tests for blas & math Ops" on Jul 10, 2024
@HangenYuu HangenYuu (Contributor, PR author) commented:
Yep, torch.argmax also does not allow multiple partial axes.

@HangenYuu HangenYuu (Contributor, PR author) commented:
@ricardoV94 BatchedDot is done. I will do Max and Argmax next. They are tougher nuts to crack.

@ricardoV94 ricardoV94 (Member) commented:
Isn't Max done already? It should be like Sum/All, ..., which we did already. Argmax just needs some ravelling and transposing. You should be able to copy the logic inside perform.
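For illustration, a hedged sketch of that ravel-and-transpose idea for torch.argmax over several axes (the helper name is hypothetical; the reference logic lives in Argmax.perform):

import torch


def argmax_over_axes(x, axes):
    # Move the kept axes to the front, flatten the reduced axes into a single
    # trailing dimension, then take one argmax over that dimension.
    keep = [i for i in range(x.ndim) if i not in axes]
    x = x.permute(*keep, *axes)
    kept_shape = x.shape[: len(keep)]
    x = x.reshape(*kept_shape, -1)
    return torch.argmax(x, dim=-1)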

def jax_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match in the 0-th dimension")
ricardoV94 (Member) commented:

Suggested change:
- raise TypeError("Shapes must match in the 0-th dimension")
+ raise TypeError("Shapes must match along the first dimension of BatchedDot")

@ricardoV94 ricardoV94 (Member) commented Jul 10, 2024

Can you split the JAX changes into a separate PR? It's better to have PRs atomic, as that makes them easier to review. Sometimes it's okay to have multiple pieces of functionality in a PR, but then you have to respect this part of the checklist:



@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
ricardoV94 (Member) commented:

We should have the dispatch for the other CAReduce Ops, like All, Any, ..., in the same file as Max.

@HangenYuu HangenYuu (Contributor, PR author) commented Jul 10, 2024

> Can you split the JAX changes into a separate PR? It's better to have PRs atomic, as that makes them easier to review. Sometimes it's okay to have multiple pieces of functionality in a PR, but then you have to respect this part of the checklist:

Okay I moved it to #913

@HangenYuu HangenYuu (Contributor, PR author) commented:
@ricardoV94 For this one, I will stop at BatchedDot.

@HangenYuu HangenYuu changed the title from "Added PyTorch link and unit tests for blas & math Ops" to "Added PyTorch link and unit tests for math.Dot and blas.BatchedDot" on Jul 15, 2024
Comment on lines 32 to 34
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
pytorch_mode = Mode(PytorchLinker(), opts)
pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode)
ricardoV94 (Member) commented:
This does the same?

Suggested change:
- opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
- pytorch_mode = Mode(PytorchLinker(), opts)
- pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode)
+ pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
+ pytensor_pytorch_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode_no_rewrites)

ricardoV94 (Member) added:

But if I am not mistaken, compare_pytorch_and_py returns the torch function, so you could just reuse it?
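A hedged sketch of that reuse inside the test, assuming compare_pytorch_and_py returns the compiled PyTorch-mode function first (as compare_jax_and_py does) and that a, b, out and the a_test/b_test values come from the surrounding test body:

fgraph = FunctionGraph([a, b], [out])
pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test])
# Further checks can call pytensor_pytorch_fn directly instead of compiling
# a second function with a hand-built Mode.
extra_res = pytensor_pytorch_fn(a_test, b_test)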

@ricardoV94 ricardoV94 changed the title from "Added PyTorch link and unit tests for math.Dot and blas.BatchedDot" to "Implement Dot and BatchedDot in PyTensor" on Jul 15, 2024
codecov bot commented Jul 17, 2024

Codecov Report

Attention: Patch coverage is 85.71429% with 3 lines in your changes missing coverage. Please review.

Project coverage is 81.40%. Comparing base (426931b) to head (f459866).
Report is 94 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/pytorch/dispatch/blas.py | 80.00% | 2 Missing ⚠️
pytensor/link/pytorch/dispatch/math.py | 87.50% | 1 Missing ⚠️
Additional details and impacted files


@@           Coverage Diff           @@
##             main     #878   +/-   ##
=======================================
  Coverage   81.40%   81.40%           
=======================================
  Files         173      175    +2     
  Lines       46914    46934   +20     
  Branches    11426    11427    +1     
=======================================
+ Hits        38188    38205   +17     
- Misses       6544     6547    +3     
  Partials     2182     2182           
Files with missing lines | Coverage Δ
pytensor/link/pytorch/dispatch/__init__.py | 100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/math.py | 87.50% <87.50%> (ø)
pytensor/link/pytorch/dispatch/blas.py | 80.00% <80.00%> (ø)

Comment on lines 32 to 33
pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites
ricardoV94 (Member) commented:

This is not a thing you can do (or rather, it has no effect). Once a function is compiled, that's it; the mode plays no role anymore.

Suggested change (remove these lines):
- pytorch_mode_no_rewrites = Mode(PytorchLinker(), None)
- pytensor_pytorch_fn.mode = pytorch_mode_no_rewrites

Comment on lines 17 to 19
a.tag.test_value = (
    np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
ricardoV94 (Member) commented:

We are getting rid of the test_value machinery. Just pass these directly to the test function; there's no point in putting them in the tag only to retrieve them again.
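For example, a minimal sketch of the suggested pattern (batched_dot, the compare_pytorch_and_py helper, and the a_test/b_test names are assumptions based on the surrounding test file):

import numpy as np

from pytensor import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.blas import batched_dot
from pytensor.tensor.type import tensor3
from tests.link.pytorch.test_basic import compare_pytorch_and_py

a = tensor3("a")
b = tensor3("b")
a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))

# No tag.test_value needed: the numerical values go straight to the helper.
out = batched_dot(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_pytorch_and_py(fgraph, [a_test, b_test])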

@ricardoV94 ricardoV94 (Member) left a comment:

Looks good, just a nit if you want to address it.

def test_pytorch_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    A = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
ricardoV94 (Member) commented:

Nit: more conventional names would be a_test and b_test for the test values of variables a and b.

@ricardoV94 ricardoV94 (Member) commented:
GitHub says the branch has conflicts with main. Can you update?

@HangenYuu HangenYuu (Contributor, PR author) replied:
> GitHub says the branch has conflicts with main. Can you update?

Done.

@ricardoV94 ricardoV94 merged commit 6ad1c5c into pymc-devs:main Jul 18, 2024
58 of 59 checks passed
@HangenYuu HangenYuu deleted the torch_dot branch July 22, 2024 02:15
Ch0ronomato pushed a commit to Ch0ronomato/pytensor that referenced this pull request Aug 15, 2024