Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forbid runtime broadcasting in Elemwise #372

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jul 4, 2023

Related to #100
Related to #149
Related to #371

pytensor/tensor/elemwise.py Outdated Show resolved Hide resolved

out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could just use this function: https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/extra_ops.py#L1465

The make_node method doesn't seem to properly take into account the broadcastable flag either though, maybe that needs an update as well?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to introduce checks or comparison between shapes, which that function does. This allows it to return a more optimized graph like Theano used to by assuming no invalid shapes were provided

The question then is whether we want to refactor that helper to do the same when arrays_are_shapes=False?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the make_node is correct insofar as it uses static shape and it's not possible to have broadcastable=False and shape=1

That one still requires some thinking and would be tackled in a separate PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to introduce checks or comparison between shapes, which that function does. This allows it to return a more optimized graph like Theano used to by assuming no invalid shapes were provided

So we allow undefined behavior in the shapes and in rewrites? I'm not sure I see that much downside with having that check here...

But at least I think we shouldn't have this logic in both places. Maybe the function should have a flag if it should return shape with or without checks?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking we should add a config.assume_shapes_correct flag (default to True) to toggle that behavior in both shape_inference and rewrites that can return simplified cases.

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that helper works differently in that it expects either shapes or arrays, but here we are combining information from both shapes and arrays so it would require some refactoring. We don't want to simply pass node.inputs since infer_shape wants us to return a graph from ishapes.

I don't know if that is the right place to implement this logic since it is a user facing function. WDYT?

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay I reverted to using the helper. Things are a bit weird in shape compilation because it will just use the static type shape of the node if that's available. Because the Elemwise make_node assumes valid shapes, the check introduced by infer_shape is only triggered when all dims are None.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not much we can do about that then I think without a major rewrite of the shape handling...

pytensor/tensor/elemwise_cgen.py Outdated Show resolved Hide resolved
"""
@staticmethod
def check_runtime_broadcast_error(mode):
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd feel better if those tests were a bit more complete, ie inputs with different lengths etc...

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean different runtime shapes (3 vs 5)? I am sure there are old tests for that already.

There are tests for invalid static shapes.

This one test was added when we specifically allowed runtime broadcasting in Aesara. The other thing I considered doing was to just remove it.

I'll confirm other tests for invalid shapes exist and maybe combine with this if they are not too convoluted.

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for incompatible non-broadcast shapes. Let me know if you meant something else

@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch from 60b6d6f to a21ae05 Compare July 10, 2023 08:05
@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch from a21ae05 to b2c2743 Compare July 10, 2023 08:19
@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch 2 times, most recently from a26e46b to f3ad19a Compare July 10, 2023 12:04
@@ -35,15 +35,20 @@ def compute_itershape(
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
with builder.if_else(
builder.or_(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird the changes cause a SegmentationFault on the BroadcastTo numba test, but only on python 3.11? I couldn't replicate locally on 3.8 either. https://github.com/pymc-devs/pytensor/actions/runs/5507826611/jobs/10039563156?pr=372

Did I do something obviously wrong @aseyboldt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see anything wrong, I can try locally with py311 and if I can reproduce I can try to look at it in a debugger (with no debugging symbols, but well...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can quickly try to reproduce that's already helpful (even if you don't dig down)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No luck so far, for me the tests run just fine...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It reliably segfaults here. I'll remove the numba changes for now and put the new test as an xfail

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it segfault during the test_BroadcastTo test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes... tests/link/numba/test_extra_ops.py::test_BroadcastTo[x0-shape0].

https://github.com/pymc-devs/pytensor/actions/runs/5507826611/jobs/10039563156?pr=372#step:6:281

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I don't see how it could be a problem in those tests. There is nothing else in the compiled graph other than the BroadcastTo

@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch from eb98809 to 1ab333d Compare July 11, 2023 08:12
@codecov-commenter
Copy link

codecov-commenter commented Jul 11, 2023

Codecov Report

Merging #372 (d044271) into main (5c87d74) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #372      +/-   ##
==========================================
- Coverage   80.40%   80.40%   -0.01%     
==========================================
  Files         156      156              
  Lines       45401    45390      -11     
  Branches    11106    11103       -3     
==========================================
- Hits        36505    36496       -9     
  Misses       6689     6689              
+ Partials     2207     2205       -2     
Impacted Files Coverage Δ
pytensor/tensor/extra_ops.py 89.00% <ø> (ø)
pytensor/link/jax/dispatch/elemwise.py 81.69% <100.00%> (+1.09%) ⬆️
pytensor/tensor/elemwise.py 88.07% <100.00%> (+0.01%) ⬆️
pytensor/tensor/elemwise_cgen.py 95.34% <100.00%> (-0.40%) ⬇️

... and 2 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch from 1ab333d to 28b3b46 Compare July 11, 2023 10:00
@ricardoV94 ricardoV94 force-pushed the revert_dynamic_broadcast_elemwise branch from 28b3b46 to d044271 Compare July 11, 2023 10:37
@ricardoV94 ricardoV94 requested a review from aseyboldt July 11, 2023 11:52
Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good :-)

@ricardoV94 ricardoV94 merged commit 981be2a into pymc-devs:main Jul 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants