Use static-only broadcasting rules to compute shape of broadcasting #345
Possibly related to #330
Nothing against the PR itself but note that these complicated graphs are likely to go away if you provide static shape for the inputs, which is the case with 99% of PyMC models.
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #345      +/-   ##
==========================================
- Coverage   80.40%   80.39%   -0.01%
==========================================
  Files         156      156
  Lines       45416    45398      -18
  Branches    11114    11106       -8
==========================================
- Hits        36515    36498      -17
+ Misses       6695     6694       -1
  Partials     2206     2206
@ricardoV94 True. I was mostly just hoping we could get the changes like in #149 piece by piece; it does seem pretty tricky to do it all in one go...
Looks good, left some nitpick suggestions.
pytensor/tensor/extra_ops.py
- # Get the shapes in this dimension that are not definitively
- # broadcastable (i.e. not symbolically known to be broadcastable)
+ # Get the shapes in this dimension that are not broadcastable
+ # (i.e. not symbolically known to be broadcastable)
  maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
Remove the "maybe"?
- maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
+ non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
pytensor/tensor/extra_ops.py
)
else:
bcast_dim = const_nt_shape_var
assert_op = Assert("Could not dynamically broadcast dimensions.")
Maybe create this Op once at the module level?
pytensor/tensor/extra_ops.py
)
else:
bcast_dim = const_nt_shape_var
assert_op = Assert("Could not dynamically broadcast dimensions.")
- assert_op = Assert("Could not dynamically broadcast dimensions.")
+ assert_op = Assert("Could not broadcast dimensions. If a variable should broadcast use `specify_shape` to inform PyTensor.")
I made it even more specific:
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"inform PyTensor of a known shape."
pytensor/tensor/extra_ops.py
continue

# Add assert that all remaining shapes are equal
use_scalars = False
Let's remove the `use_scalars` block?
One thing we could consider in the future is a |
I think the only thing that's needed is to allow the FusionRewrite to work on 0d tensors (right now I think it requires that ndim > 1). Fusing chains of 0d tensors inside a Composite would have the same effect as using scalars in the graph with a small overhead from Once we have a 0d Elemwise composite it's also trivial to replace it by the scalar case if that's more efficient. Created issue: #349 |
I don't think I understand what you mean with the FusionRewrite. If we used the code that never produces an Elemwise in the first place, then how would Elemwise Fusion matter?
If you used the code with |
I fixed the suggestions above.
But if we work on tensors anyway, I don't think that would be better than the all(eq(...)) code, would it? The point was more that I would like not to have any tensors in here at all. Since |
I'm gonna mark this as major so that we are forced to manually bump the dependency on pymc and run the test suite (in case it breaks something there).
Motivation for these changes
As discussed here, we want to go back to static-only broadcasting, which simplifies the rules for the expected shapes of broadcasted arrays quite a bit. This PR changes the broadcast_shape function to take advantage of this, which in turn simplifies a lot of graphs significantly.
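To make the rule concrete, here is a rough pure-Python sketch of how a static-only shape computation could look. It is not the code in this PR: `static_broadcast_shape` is a hypothetical helper, shapes are plain tuples, and `None` stands in for a length that is unknown until runtime. The only assumption taken from the discussion is that a dimension broadcasts only when its length is statically known to be 1.

```python
def static_broadcast_shape(*shapes):
    """Sketch of static-only broadcasting (hypothetical helper, not PyTensor API).

    Each shape is a tuple of ints or None (None = length unknown until runtime).
    Returns (result_shape, needs_runtime_check), one entry per output dimension.
    """
    ndim = max(len(s) for s in shapes)
    # Left-pad with 1s so that all shapes have the same number of dimensions.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]

    result, needs_runtime_check = [], []
    for dim_lengths in zip(*padded):
        # Only a statically known 1 is allowed to broadcast.
        non_bcast = [d for d in dim_lengths if d != 1]
        known = {d for d in non_bcast if d is not None}
        if len(known) > 1:
            # Two different static lengths can never be compatible.
            raise ValueError(f"Incompatible static lengths: {sorted(known)}")
        if not non_bcast:
            result.append(1)
            needs_runtime_check.append(False)
            continue
        # Prefer the statically known length; otherwise the length stays unknown.
        result.append(known.pop() if known else None)
        # An unknown length that has to match another input needs a runtime assert.
        needs_runtime_check.append(None in non_bcast and len(non_bcast) > 1)
    return tuple(result), tuple(needs_runtime_check)


print(static_broadcast_shape((3, 1), (1, 4)))   # ((3, 4), (False, False))
print(static_broadcast_shape((None, 1), (5,)))  # ((None, 5), (False, False))
print(static_broadcast_shape((None,), (3,)))    # ((3,), (True,))
```

In this sketch no dimension ever needs a "maybe it is 1 at runtime" branch, which is the part that presumably made the old graphs complicated.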
Implementation details
I'm not sure yet how we should code the condition that triggers the assert for arrays that aren't compatible.
The current code contains an implementation that uses Elemwise, and one that uses scalar expressions.
The scalar expressions seem a bit more reasonable to me (not tensors and no allocations), but since most rewrites only work on Elemwise, this can prevent a lot of optimizations right now.
Checklist
Major / Breaking Changes
This can lead to dynamic assertion failures for code that relied on the (never fully implemented) support for dynamic broadcasting.
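A small pure-Python sketch of the kind of failure meant here, assuming the static-only rule described in this PR. `runtime_broadcast_length` is a hypothetical stand-in for what the asserted graph would evaluate for a single output dimension; it is not PyTensor code, and the message is the wording proposed in the review thread.

```python
def runtime_broadcast_length(static_lengths, runtime_lengths):
    """Resolve one output dimension at runtime (hypothetical sketch).

    static_lengths: per-input length known at graph-build time, or None.
    runtime_lengths: per-input length actually observed at runtime.
    """
    # Inputs whose length is statically known to be 1 are the only ones that broadcast.
    remaining = [r for s, r in zip(static_lengths, runtime_lengths) if s != 1]
    if not remaining:
        return 1
    # All remaining inputs must agree exactly, even if one of them happens to be 1.
    if any(r != remaining[0] for r in remaining):
        raise AssertionError(
            "Could not broadcast dimensions. Broadcasting is only allowed along "
            "axes that have a statically known length 1. Use `specify_shape` to "
            "inform PyTensor of a known shape."
        )
    return remaining[0]


# Length 1 is known statically, so broadcasting against 5 is fine.
print(runtime_broadcast_length([1, None], [1, 5]))  # 5

# Same runtime lengths, but the 1 is not known statically: this now fails.
try:
    runtime_broadcast_length([None, None], [1, 5])
except AssertionError as err:
    print("assertion failed:", err)
```

Code that fed an input with an unknown static shape and relied on it being length 1 at runtime would hit the second case, and would need to declare the shape up front to keep working.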
Example of before and after
Before
After