-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Static broadcast #149
Static broadcast #149
Conversation
8d9c142
to
7dc4035
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #149 +/- ##
==========================================
- Coverage 79.98% 79.97% -0.01%
==========================================
Files 169 169
Lines 44607 44619 +12
Branches 9426 9431 +5
==========================================
+ Hits 35678 35685 +7
- Misses 6738 6741 +3
- Partials 2191 2193 +2
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I think we should already revert the Python and C perform code to fail with dynamic broadcasting like it did before (with a better error message than it used to). That will highlight any functionality that may still be implicitly depending on dynamic broadcasting.
In addition we should add the gradient broadcasting tests that showed up in the original issue in the Aesara repo.
Related, but not necessarily in this PR we also have to disable dynamic broadcasting that's done by RandomVariables. We should at least open an issue to track that.
7dc4035
to
22b436b
Compare
This PR is a pandora box. There were a lot of small fixes to the new broadcasting. I still find some places where I'm unsure gradients will correctly propagate wit dynamic broadcasting. https://github.com/aesara-devs/aesara/pulls?page=2&q=is%3Apr+broadcast+is%3Aclosed One of the concerns so far is the way infer shape works in |
out_broadcastable = tuple(all(bcast) for bcast in zip(*broadcast_patterns)) | ||
except ValueError as e: | ||
raise ValueError( | ||
"Incompatible Elemwise input broadcasting pattern: " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about something like "Incompatible Elemwise input broadcasting: Broadcasting is only allowed if the shape of the broadcasted axis is statically known to be one. Use input.specify_shape
to inform pytensor that a shape is 1."
I don't think we have to explain why we do it like this in the error message. We could also add a faq entry, and link to that.
Yeah, I was afraid that might be the case. I think we should go ahead and merge the cases that we know about, instead of waiting a long time to try and find everything in the first PR. It would be great if we had some way of testing this automatically, but I don't really know how that would work... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to revise type.filter_variable
, type.is_super
type.in_same_class
to consider broadcasting flags. Those are called during rewrites to make sure the replacement types are compatible with the original types and/or to apply some simple operations if that would make the types equivalent (e.g., add a specify_shape
and now an "Unbroadcast" perhaps).
Also we should check if we can prevent dynamic broadcasting in the JAX dispatch of Elemwise. That doesn't need to be done in this PR but we should confirm it can indeed be done.
692542d
to
885ff0c
Compare
I've changed that. Seems like rewrites are important to check now UPD: replaces indeed get broken |
Once I go through small fixes they arise more problems. Fixing the variable filtering I just discovered
|
} | ||
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x}) | ||
{{ | ||
PyErr_Format(PyExc_ValueError, "Input dimension mismatch implicit broadcasting is not supported. (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth it to make the error conditional on numpy broadcasting case. In that case we can say not supported, link to the FAQ and whatever. On the case where it's a mismatch without shape of 1 we should just have the vanilla error message. This is the error most users will be hitting as long as C is the default backend.
4971ac3
to
5facd79
Compare
The issue that pops up seems to be introduced when refactoring Alloc rewrites here https://github.com/aesara-devs/aesara/pull/1102/files |
Yes, it makes sense to revert those changes. The rewrite originally followed the broadcastable conventions as you can see from the original issue aesara-devs/aesara#1094. It no longer respects them because it was rewritten to support dynamic broadcasting |
yeah, the changes are in this specific commit |
5facd79
to
e5b5294
Compare
@ricardoV94 @ferrine the conflict in |
it will be fun to go back to this PR, a lot of rebase conflicts... |
I think it might work better to take it piece by piece, maybe without attempting a direct git revert. I'll try to spin-off a PR to reintroduce it for Elemwise. We can leave the Blas Ops for later PR #372 |
Closing as we did already some progress elsewhere |
Motivation for these changes
Dynamic broadcasting creates tremendous graph obfuscations. While in forward pass it is not visible, the backward pass should always check if the broadcasting had happened or not. It may sound simple, but still creates 2^n if else statements. Originally, theano had static broadcasting and the same we get in this PR
Related Issues and PRs
Elemwise.c_code
aesara-devs/aesara#928Type
inference inRandomVariable
andScan
aesara-devs/aesara#1253TensorType.broadcastable
usage fromlocal_elemwise_alloc
aesara-devs/aesara#1102Implementation details
Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance