Static broadcast #149

Closed
wants to merge 3 commits into from

Conversation

ferrine
Member

@ferrine ferrine commented Dec 23, 2022

Motivation for these changes

Dynamic broadcasting badly obfuscates the graph. This is not visible in the forward pass, but the backward pass always has to check whether broadcasting happened or not. It may sound simple, but it still creates 2^n if/else statements. Originally, Theano had static broadcasting, and this PR brings the same behaviour back.
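
To make the 2^n claim concrete, here is a minimal pure-Python sketch. It assumes n is the number of axes whose length is unknown when the graph is built (written as `None`); the helper names `dynamic_cases` and `static_case` are hypothetical and are not part of PyTensor.

```python
from itertools import product


def dynamic_cases(static_shape):
    """Enumerate the broadcast patterns a gradient may face at runtime.

    Assumption: with dynamic broadcasting, an axis of unknown length (None)
    may or may not turn out to have length 1, so both options stay open.
    """
    options = [
        (False, True) if length is None else (length == 1,)
        for length in static_shape
    ]
    return list(product(*options))


def static_case(static_shape):
    """With static broadcasting only a known length of 1 is broadcastable."""
    return tuple(length == 1 for length in static_shape)


# Three unknown axes leave 2**3 patterns open under dynamic broadcasting,
# while static broadcasting fixes a single pattern at graph-construction time.
print(len(dynamic_cases((None, None, None))))
print(static_case((None, 1, None)))
```

Under these assumptions, every unknown axis doubles the number of situations the backward pass has to branch on, whereas the static rule decides the pattern once.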

Related Issues and PRs

Implementation details

  • Removed deprecations of broadcasting in TensorType
  • Changed tests to check cases against static broadcasting
  • Brought back shape inference for multioutput elemwise ops

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

@ferrine ferrine force-pushed the static-broadcast branch 10 times, most recently from 8d9c142 to 7dc4035 Compare December 28, 2022 12:37
@ferrine ferrine marked this pull request as ready for review December 28, 2022 13:25
@ferrine ferrine requested a review from ricardoV94 December 28, 2022 13:25
@codecov-commenter

Codecov Report

Merging #149 (7dc4035) into main (f4de2fd) will decrease coverage by 0.00%.
The diff coverage is 84.12%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #149      +/-   ##
==========================================
- Coverage   79.98%   79.97%   -0.01%     
==========================================
  Files         169      169              
  Lines       44607    44619      +12     
  Branches     9426     9431       +5     
==========================================
+ Hits        35678    35685       +7     
- Misses       6738     6741       +3     
- Partials     2191     2193       +2     
Impacted Files Coverage Δ
pytensor/tensor/type.py 92.98% <77.77%> (-1.22%) ⬇️
pytensor/tensor/elemwise.py 88.29% <88.88%> (+0.20%) ⬆️

Member

@ricardoV94 ricardoV94 left a comment


Looks good. I think we should already revert the Python and C perform code to fail with dynamic broadcasting like it did before (with a better error message than it used to). That will highlight any functionality that may still be implicitly depending on dynamic broadcasting.

In addition we should add the gradient broadcasting tests that showed up in the original issue in the Aesara repo.

Related, but not necessarily for this PR: we also have to disable the dynamic broadcasting that's done by RandomVariables. We should at least open an issue to track that.

pytensor/tensor/elemwise.py (outdated)
@ferrine ferrine marked this pull request as draft December 29, 2022 15:18
@ferrine
Member Author

ferrine commented Dec 29, 2022

This PR is a Pandora's box. There were a lot of small fixes to the new broadcasting. I still find some places where I'm unsure gradients will correctly propagate with dynamic broadcasting.

https://github.com/aesara-devs/aesara/pulls?page=2&q=is%3Apr+broadcast+is%3Aclosed

One of the concerns so far is the way infer shape works in
https://github.com/search?q=repo%3Apymc-devs%2Fpytensor%20broadcast_shape&type=code

out_broadcastable = tuple(all(bcast) for bcast in zip(*broadcast_patterns))
except ValueError as e:
raise ValueError(
"Incompatible Elemwise input broadcasting pattern: "
Member


How about something like "Incompatible Elemwise input broadcasting: Broadcasting is only allowed if the shape of the broadcasted axis is statically known to be one. Use input.specify_shape to inform pytensor that a shape is 1."
I don't think we have to explain why we do it like this in the error message. We could also add an FAQ entry and link to that.
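
As a rough illustration of the rule behind the suggested message, the sketch below combines static shapes the way static broadcasting would: an axis broadcasts only if its length is statically known to be 1. The function name `static_elemwise_shape` is hypothetical, `None` stands for an unknown length, and this is not the code under review.

```python
def static_elemwise_shape(*static_shapes):
    """Combine same-rank static shapes (int or None per axis)."""
    if len({len(shape) for shape in static_shapes}) != 1:
        raise ValueError("All inputs must have the same number of dimensions")
    out = []
    for lengths in zip(*static_shapes):
        # Inputs statically known to have length 1 broadcast and drop out.
        non_bcast = [length for length in lengths if length != 1]
        known = {length for length in non_bcast if length is not None}
        if len(known) > 1:
            raise ValueError(
                "Incompatible Elemwise input broadcasting: Broadcasting is "
                "only allowed if the shape of the broadcasted axis is "
                "statically known to be one."
            )
        if not non_bcast:
            out.append(1)  # every input is broadcastable on this axis
        elif known:
            out.append(known.pop())
        else:
            out.append(None)  # unknown, assumed not to be 1 at runtime
    return tuple(out)


print(static_elemwise_shape((1, None), (3, None)))
```

The output axis is broadcastable only when every input is, which matches the `all(bcast)` expression in the snippet under review.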

@aseyboldt
Member

This PR is a Pandora's box. There were a lot of small fixes to the new broadcasting. I still find some places where I'm unsure gradients will correctly propagate with dynamic broadcasting.

Yeah, I was afraid that might be the case. I think we should go ahead and merge the cases that we know about, instead of waiting a long time to try and find everything in the first PR.

It would be great if we had some way of testing this automatically, but I don't really know how that would work...

Member

@ricardoV94 ricardoV94 left a comment


We need to revise type.filter_variable, type.is_super and type.in_same_class to consider broadcasting flags. Those are called during rewrites to make sure the replacement types are compatible with the original types and/or to apply some simple operations if that would make the types equivalent (e.g., add a specify_shape and now an "Unbroadcast" perhaps).

Also we should check if we can prevent dynamic broadcasting in the JAX dispatch of Elemwise. That doesn't need to be done in this PR but we should confirm it can indeed be done.
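
One possible reading of "consider broadcasting flags" in the type checks mentioned above, sketched in pure Python. The names `broadcast_pattern` and `is_compatible_replacement` are hypothetical, shapes use `None` for unknown lengths, and this is an assumption about the intended rule, not PyTensor's actual `is_super` or `in_same_class` logic.

```python
def broadcast_pattern(static_shape):
    """An axis is broadcastable only if its length is statically 1."""
    return tuple(length == 1 for length in static_shape)


def is_compatible_replacement(original, replacement):
    """Hypothetical check a rewrite could run before swapping variables."""
    if len(original) != len(replacement):
        return False
    # Assumed rule: the broadcast flags must agree exactly.
    if broadcast_pattern(original) != broadcast_pattern(replacement):
        return False
    # The replacement may be more specific, but must not contradict
    # any length the original already knows.
    return all(o is None or o == r for o, r in zip(original, replacement))


print(is_compatible_replacement((None, 3), (5, 3)))
print(is_compatible_replacement((None, 3), (1, 3)))
```

Under this reading, a replacement that silently turns a non-broadcastable axis into a broadcastable one is rejected, and the rewrite would need an explicit operation such as the specify_shape or "Unbroadcast" mentioned above.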

@ferrine ferrine force-pushed the static-broadcast branch 2 times, most recently from 692542d to 885ff0c Compare January 15, 2023 17:08
@ferrine
Member Author

ferrine commented Jan 15, 2023

We need to revise type.filter_variable, type.is_super and type.in_same_class to consider broadcasting flags. Those are called during rewrites to make sure the replacement types are compatible with the original types and/or to apply some simple operations if that would make the types equivalent (e.g., add a specify_shape and now an "Unbroadcast" perhaps).

Also we should check if we can prevent dynamic broadcasting in the JAX dispatch of Elemwise. That doesn't need to be done in this PR but we should confirm it can indeed be done.

I've changed that. It seems rewrites are important to check now.

UPD: replacements do indeed get broken.

@ferrine
Member Author

ferrine commented Jan 15, 2023

As I go through the small fixes, more problems arise. While fixing the variable filtering I just discovered:

  • sparse variables have incomplete support for broadcasting.
  • constant folding violates broadcastable property
  • gradients sometimes ignore broadcasting property

pytensor/tensor/blas.py (outdated)
}
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch implicit broadcasting is not supported. (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
Member


It may be worth making the error conditional on the NumPy broadcasting case. In that case we can say it is not supported and link to the FAQ. When the mismatch does not involve a shape of 1, we should just raise the vanilla error message. This is the error most users will hit as long as C is the default backend.
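
A minimal pure-Python sketch of that two-branch check, for illustration only. The function name `check_runtime_shapes` and the message wording are hypothetical; the real check lives in the generated C code shown above.

```python
def check_runtime_shapes(runtime_shapes, broadcastable):
    """Validate runtime shapes of same-rank Elemwise inputs.

    broadcastable[i][axis] is True when input i is statically known to have
    length 1 on that axis; such inputs may broadcast and are skipped here.
    """
    for axis in range(len(runtime_shapes[0])):
        ref = None  # (input index, length) of the first non-broadcastable input
        for i, (shape, bcast) in enumerate(zip(runtime_shapes, broadcastable)):
            if bcast[axis]:
                continue
            if ref is None:
                ref = (i, shape[axis])
                continue
            j, expected = ref
            if shape[axis] == expected:
                continue
            detail = (
                f"(input[{j}].shape[{axis}] = {expected}, "
                f"input[{i}].shape[{axis}] = {shape[axis]})"
            )
            if 1 in (shape[axis], expected):
                # NumPy would broadcast here; point users at the static rule.
                raise ValueError(
                    "Runtime broadcasting is not supported; the length-1 "
                    "axis must be known statically. " + detail
                )
            raise ValueError("Input dimension mismatch. " + detail)


# Passes silently: the first input is statically broadcastable on axis 0.
check_runtime_shapes([(1, 3), (2, 3)], [(True, False), (False, False)])
```

Only the first branch needs the longer explanation and the FAQ link; a plain mismatch such as 4 versus 2 keeps the short message.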

@ferrine
Member Author

ferrine commented Jan 16, 2023

The issue that pops up seems to have been introduced when the Alloc rewrites were refactored here

https://github.com/aesara-devs/aesara/pull/1102/files

@ricardoV94
Member

ricardoV94 commented Jan 17, 2023

The issue that pops up seems to have been introduced when the Alloc rewrites were refactored here

https://github.com/aesara-devs/aesara/pull/1102/files

Yes, it makes sense to revert those changes. The rewrite originally followed the broadcastable conventions, as you can see from the original issue aesara-devs/aesara#1094. It no longer respects them because it was rewritten to support dynamic broadcasting.

@ferrine
Member Author

ferrine commented Jan 17, 2023

yeah, the changes are in this specific commit
aesara-devs/aesara@f604e1f

@michaelosthege
Member

@ricardoV94 @ferrine the conflict in elemwise.py should probably be resolved by accepting the incoming version.
For the one in elemwise_cgen.py it looks like a similar change was done on both branches. The original might have been @ricardoV94's edit and @ferrine copied the changes to this branch?
I don't understand this part enough to resolve it =/

@ferrine
Member Author

ferrine commented Jun 29, 2023

it will be fun to go back to this PR, a lot of rebase conflicts...

@ricardoV94
Member

ricardoV94 commented Jul 4, 2023

I think it might work better to take it piece by piece, maybe without attempting a direct git revert. I'll try to spin off a PR to reintroduce it for Elemwise. We can leave the Blas Ops for later.

PR #372

@ricardoV94
Member

Closing, as we have already made some progress elsewhere.

@ricardoV94 ricardoV94 closed this Aug 24, 2023
@ricardoV94 ricardoV94 deleted the static-broadcast branch June 13, 2024 12:44
6 participants