-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add shape_unsafe
tag to rewrites that can hide shape errors
#381
Conversation
shape_unsafe
tag to rewrites that make shape assumptions
shape_unsafe
tag to rewrites that make shape assumptionsshape_unsafe
tag to rewrites that can hide shape errors
shape_unsafe
tag to rewrites that can hide shape errorsshape_unsafe
tag to rewrites that can hide shape errors
@@ -1757,7 +1616,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: | |||
The arrays to broadcast. | |||
|
|||
""" | |||
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args) | |||
|
|||
def broadcast_with_others(a, others): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed with @aseyboldt that it may make sense to generalize Second
so that it accepts arbitrary many inputs and returns every variable as output. This would become a flat broadcast_arrays
once Elemwise
d, and make rewrites easier to read. By overriding the __str__
we can also make it much more readable in debug_print
than the current nested Second
1f614bd
to
f1d19f7
Compare
f1d19f7
to
2a3adbe
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #381 +/- ##
==========================================
- Coverage 80.44% 80.41% -0.03%
==========================================
Files 156 156
Lines 45470 45413 -57
Branches 11136 11119 -17
==========================================
- Hits 36578 36520 -58
- Misses 6687 6693 +6
+ Partials 2205 2200 -5
|
Fixing #379 should also help with the "unsafety" concerns |
@@ -1561,141 +1561,6 @@ def broadcast_shape_iter( | |||
return tuple(result_dims) | |||
|
|||
|
|||
class BroadcastTo(COp): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BroadcastTo
is imported in pymc a couple of times. Maybe we should leave an empty Op here, that is deprecated and doesn't do anything?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of the removed rewrites are also directly imported.
This shouldn't be a problem however. I marked this PR as a major release so we will bump the version above the upper-bound pinned by PyMC. When we update the pin on PyMC I'll address the changes. They require some manual review anyway to see if the logic that depended on BroadcastTo was valid per our new rules and can be transferred to Alloc.
This was all on the logprob inference module AFAICT so impact should be pretty contained.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good
Other than the two suggestions above this looks good :-) |
2a3adbe
to
9f8ed94
Compare
BroadcastTo
in favor ofAlloc
Note that from a user standpoint, providing static shapes (via
vector("x", shape=(5,))
orspecify_shapes
) will many times reveal shape errors immediately (this is the case for 99% of PyMC models). In this case users should feel pretty safe about "shape_unsafe" rewrites because they aren't really masking anything that wasn't checked before already.Alloc.make_node
now also raises early when it can see the provided shape is inconsistent. Alloc and Elemwise make up all of the tagged "shape_unsafe" rewrites so far.With this PR, users can also do
mode=get_default_mode().excluding("shape_unsafe")
or addshape_unsafe
to theexcluding
config to skip these rewrites at the cost of less optimizations.Closes #367