
Use static-only broadcasting rules to compute shape of broadcasting #345

Merged (2 commits, Jun 17, 2023)

Conversation

@aseyboldt (Member) commented Jun 15, 2023

Motivation for these changes

As discussed here we want to go back to static broadcasting only, which simplifies the rules for the expected shapes of broadcasted arrays quite a bit. This PR changes the broadcast_shape function to take advantage of this, which then simplifies a lot of graphs significantly.

Implementation details

I'm not sure yet how we should encode the condition that triggers the assert for arrays that aren't compatible.
The current code contains one implementation that uses Elemwise, and one that uses scalar expressions.
The scalar expressions seem a bit more reasonable to me (no tensors and no allocations), but since most rewrites only work on Elemwise, this can prevent a lot of optimizations right now.
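For illustration, a minimal sketch of the Elemwise-based variant, using the all(eq(...)) pattern mentioned later in this thread (the helper name and structure are mine, not the PR's actual code):

import pytensor.tensor as pt
from pytensor.raise_op import Assert

# Hypothetical sketch: guard the length chosen for one axis with an
# assert that all non-broadcastable lengths along that axis agree.
def guarded_dim_length(dims):
    # dims: 0d int64 tensors holding each input's length along one axis,
    # after entries statically known to be 1 have been dropped.
    # Assumes at least two entries remain.
    first = dims[0]
    conds = pt.stack([pt.eq(first, d) for d in dims[1:]])
    assert_op = Assert("Could not dynamically broadcast dimensions.")
    return assert_op(first, pt.all(conds))

The scalar-expression variant would instead route the lengths through ScalarFromTensor and build the conjunction from pytensor.scalar ops, avoiding tensor allocations, but most rewrites would then no longer match it.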

Checklist

Major / Breaking Changes

This can lead to dynamic assertion failures for code that relied on the (never fully implemented) support for dynamic broadcasting.
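A hedged example of what this can look like in user code (the setup is illustrative; specify_shape is the documented way to declare a static length):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")  # no static shape information
y = pt.matrix("y")
f = pytensor.function([x, y], x + y)

# Relying on a runtime length-1 dimension to broadcast can now fail
# with "Could not dynamically broadcast dimensions.":
# f(np.ones((1, 3)), np.ones((5, 3)))

# Fix: declare the broadcastable dimension statically.
x1 = pt.specify_shape(x, (1, None))
f2 = pytensor.function([x, y], x1 + y)
f2(np.ones((1, 3)), np.ones((5, 3)))  # broadcasts as expected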

Example of before and after

import pytensor.tensor as pt
import pytensor
import pymc as pm
import numpy as np

# The gradient of a ZeroSumTransform backward pass exercises the
# broadcast-shape computation shown in the graphs below.
trafo = pm.distributions.multivariate.ZeroSumTransform([1])

x = pt.dmatrix("x")
y = trafo.backward(x)
grad = pt.grad(y.sum(), x)

func = pytensor.function([x], grad)
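The Before/After listings below are debug-print output; they can be reproduced with something like:

pytensor.dprint(func)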

Before

Add [id A] 27
 ├─ Split{2}.0 [id B] 15
 │  ├─ Alloc [id C] 13
 │  │  ├─ [[1.]] [id D]
 │  │  ├─ TensorFromScalar [id E] 11
 │  │  │  └─ Assert{msg=Could not broadcast dimensions} [id F] 9
 │  │  │     ├─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id G] 7
 │  │  │     │  ├─ ScalarFromTensor [id H] 4
 │  │  │     │  │  └─ Shape_i{0} [id I] 1
 │  │  │     │  │     └─ x [id J]
 │  │  │     │  └─ ScalarFromTensor [id H] 4
 │  │  │     │     └─ ···
 │  │  │     └─ Composite{...} [id K] 6
 │  │  │        ├─ ScalarFromTensor [id H] 4
 │  │  │        │  └─ ···
 │  │  │        └─ ScalarFromTensor [id H] 4
 │  │  │           └─ ···
 │  │  └─ Add [id L] 8
 │  │     ├─ Shape_i{1} [id M] 0
 │  │     │  └─ x [id J]
 │  │     └─ 1 [id N]
 │  ├─ 1 [id O]
 │  └─ MakeVector{dtype='int64'} [id P] 2
 │     ├─ Shape_i{1} [id M] 0
 │     │  └─ ···
 │     └─ 1 [id N]
 ├─ Composite{((i0 + i1) / i2)} [id Q] 25
 │  ├─ SpecifyShape [id R] 17
 │  │  ├─ Split{2}.1 [id B] 15
 │  │  │  └─ ···
 │  │  ├─ NoneConst{None} [id S]
 │  │  └─ 1 [id O]
 │  ├─ BroadcastTo [id T] 24
 │  │  ├─ Composite{cast{float64}((-i0))} [id U] 12
 │  │  │  └─ ExpandDims{axes=[0, 1]} [id V] 10
 │  │  │     └─ Add [id L] 8
 │  │  │        └─ ···
 │  │  ├─ TensorFromScalar [id W] 23
 │  │  │  └─ Assert{msg=Could not broadcast dimensions} [id X] 22
 │  │  │     ├─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id Y] 21
 │  │  │     │  ├─ ScalarFromTensor [id Z] 19
 │  │  │     │  │  └─ Subtensor{i} [id BA] 18
 │  │  │     │  │     ├─ SetSubtensor{i} [id BB] 16
 │  │  │     │  │     │  ├─ MakeVector{dtype='int64'} [id BC] 14
 │  │  │     │  │     │  │  ├─ TensorFromScalar [id E] 11
 │  │  │     │  │     │  │  │  └─ ···
 │  │  │     │  │     │  │  └─ Add [id L] 8
 │  │  │     │  │     │  │     └─ ···
 │  │  │     │  │     │  ├─ 1 [id N]
 │  │  │     │  │     │  └─ 1 [id BD]
 │  │  │     │  │     └─ 0 [id BE]
 │  │  │     │  └─ Assert{msg=Could not broadcast dimensions} [id F] 9
 │  │  │     │     └─ ···
 │  │  │     └─ Composite{...} [id BF] 20
 │  │  │        ├─ ScalarFromTensor [id Z] 19
 │  │  │        │  └─ ···
 │  │  │        └─ Assert{msg=Could not broadcast dimensions} [id F] 9
 │  │  │           └─ ···
 │  │  └─ 1 [id N]
 │  └─ Composite{...}.1 [id BG] 5
 │     ├─ [[1]] [id BH]
 │     ├─ ExpandDims{axes=[0, 1]} [id BI] 3
 │     │  └─ Shape_i{1} [id M] 0
 │     │     └─ ···
 │     └─ [[1.]] [id D]
 └─ Composite{((-i0) / i1)} [id BJ] 26
    ├─ SpecifyShape [id R] 17
    │  └─ ···
    └─ Composite{...}.0 [id BG] 5
       └─ ···

Inner graphs:

Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id G]
 ← Abs [id BK] 'o0'
    └─ maximum [id BL]
       ├─ Switch [id BM]
       │  ├─ EQ [id BN]
       │  │  ├─ i0 [id BO]
       │  │  └─ t8{1} [id BP]
       │  ├─ neg [id BQ] 't1'
       │  │  └─ 1 [id BR]
       │  └─ i0 [id BO]
       └─ Switch [id BS]
          ├─ EQ [id BT]
          │  ├─ i1 [id BU]
          │  └─ t8{1} [id BP]
          ├─ neg [id BQ] 't1'
          │  └─ ···
          └─ i1 [id BU]

Composite{...} [id K]
 ← AND [id BV] 'o0'
    ├─ OR [id BW]
    │  ├─ EQ [id BX]
    │  │  ├─ Switch [id BY] 't10'
    │  │  │  ├─ EQ [id BZ]
    │  │  │  │  ├─ i0 [id CA]
    │  │  │  │  └─ t8{1} [id BP]
    │  │  │  ├─ neg [id CB] 't15'
    │  │  │  │  └─ 1 [id BR]
    │  │  │  └─ i0 [id CA]
    │  │  └─ neg [id CB] 't15'
    │  │     └─ ···
    │  └─ EQ [id CC]
    │     ├─ Switch [id BY] 't10'
    │     │  └─ ···
    │     └─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id CD] 't12'
    │        ├─ i0 [id CA]
    │        └─ i1 [id CE]
    └─ OR [id CF]
       ├─ EQ [id CG]
       │  ├─ Switch [id CH] 't0'
       │  │  ├─ EQ [id CI]
       │  │  │  ├─ i1 [id CE]
       │  │  │  └─ t8{1} [id BP]
       │  │  ├─ neg [id CB] 't15'
       │  │  │  └─ ···
       │  │  └─ i1 [id CE]
       │  └─ neg [id CB] 't15'
       │     └─ ···
       └─ EQ [id CJ]
          ├─ Switch [id CH] 't0'
          │  └─ ···
          └─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id CD] 't12'
             └─ ···

Composite{((i0 + i1) / i2)} [id Q]
 ← true_div [id CK] 'o0'
    ├─ add [id CL]
    │  ├─ i0 [id CM]
    │  └─ i1 [id CN]
    └─ i2 [id CO]

Composite{cast{float64}((-i0))} [id U]
 ← Cast{float64} [id CP] 'o0'
    └─ neg [id CQ]
       └─ i0 [id CR]

Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id Y]
 ← Abs [id CS] 'o0'
    └─ maximum [id CT]
       ├─ Switch [id CU]
       │  ├─ EQ [id CV]
       │  │  ├─ i0 [id CW]
       │  │  └─ t8{1} [id CX]
       │  ├─ neg [id CY] 't1'
       │  │  └─ 1 [id CZ]
       │  └─ i0 [id CW]
       └─ Switch [id DA]
          ├─ EQ [id DB]
          │  ├─ i1 [id DC]
          │  └─ t8{1} [id CX]
          ├─ neg [id CY] 't1'
          │  └─ ···
          └─ i1 [id DC]

Composite{...} [id BF]
 ← AND [id DD] 'o0'
    ├─ OR [id DE]
    │  ├─ EQ [id DF]
    │  │  ├─ Switch [id DG] 't6'
    │  │  │  ├─ EQ [id DH]
    │  │  │  │  ├─ i0 [id DI]
    │  │  │  │  └─ t8{1} [id CX]
    │  │  │  ├─ neg [id DJ] 't3'
    │  │  │  │  └─ 1 [id CZ]
    │  │  │  └─ i0 [id DI]
    │  │  └─ neg [id DJ] 't3'
    │  │     └─ ···
    │  └─ EQ [id DK]
    │     ├─ Switch [id DG] 't6'
    │     │  └─ ···
    │     └─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id DL] 't16'
    │        ├─ i0 [id DI]
    │        └─ i1 [id DM]
    └─ OR [id DN]
       ├─ EQ [id DO]
       │  ├─ Switch [id DP] 't15'
       │  │  ├─ EQ [id DQ]
       │  │  │  ├─ i1 [id DM]
       │  │  │  └─ t8{1} [id CX]
       │  │  ├─ neg [id DJ] 't3'
       │  │  │  └─ ···
       │  │  └─ i1 [id DM]
       │  └─ neg [id DJ] 't3'
       │     └─ ···
       └─ EQ [id DR]
          ├─ Switch [id DP] 't15'
          │  └─ ···
          └─ Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id DL] 't16'
             └─ ···

Composite{...} [id BG]
 ← sqrt [id DS] 'o0'
    └─ add [id DT]
       ├─ i0 [id DU]
       └─ i1 [id DV]
 ← add [id DW] 'o1'
    ├─ i2 [id DX]
    ├─ sqrt [id DS] 'o0'
    │  └─ ···
    └─ i1 [id DV]

Composite{((-i0) / i1)} [id BJ]
 ← true_div [id DY] 'o0'
    ├─ neg [id DZ]
    │  └─ i0 [id EA]
    └─ i1 [id EB]

Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id CD]
 ← Abs [id BK] 'o0'
    └─ ···

Composite{Abs(maximum(Switch(EQ(i0, 1), (-1), i0), Switch(EQ(i1, 1), (-1), i1)))} [id DL]
 ← Abs [id CS] 'o0'
    └─ ···

After

Add [id A] 19
 ├─ Split{2}.0 [id B] 9
 │  ├─ Alloc [id C] 6
 │  │  ├─ [[1.]] [id D]
 │  │  ├─ Shape_i{0} [id E] 1
 │  │  │  └─ x [id F]
 │  │  └─ Add [id G] 5
 │  │     ├─ Shape_i{1} [id H] 0
 │  │     │  └─ x [id F]
 │  │     └─ 1 [id I]
 │  ├─ 1 [id J]
 │  └─ MakeVector{dtype='int64'} [id K] 2
 │     ├─ Shape_i{1} [id H] 0
 │     │  └─ ···
 │     └─ 1 [id I]
 ├─ Composite{((i0 + i1) / i2)} [id L] 17
 │  ├─ SpecifyShape [id M] 12
 │  │  ├─ Split{2}.1 [id B] 9
 │  │  │  └─ ···
 │  │  ├─ NoneConst{None} [id N]
 │  │  └─ 1 [id J]
 │  ├─ BroadcastTo [id O] 16
 │  │  ├─ Composite{cast{float64}((-i0))} [id P] 11
 │  │  │  └─ ExpandDims{axes=[0, 1]} [id Q] 8
 │  │  │     └─ Add [id G] 5
 │  │  │        └─ ···
 │  │  ├─ Assert{msg=Could not dynamically broadcast dimensions.} [id R] 15
 │  │  │  ├─ Subtensor{i} [id S] 13
 │  │  │  │  ├─ SetSubtensor{i} [id T] 10
 │  │  │  │  │  ├─ MakeVector{dtype='int64'} [id U] 7
 │  │  │  │  │  │  ├─ Shape_i{0} [id E] 1
 │  │  │  │  │  │  │  └─ ···
 │  │  │  │  │  │  └─ Add [id G] 5
 │  │  │  │  │  │     └─ ···
 │  │  │  │  │  ├─ 1 [id I]
 │  │  │  │  │  └─ 1 [id V]
 │  │  │  │  └─ 0 [id W]
 │  │  │  └─ Eq [id X] 14
 │  │  │     ├─ Subtensor{i} [id S] 13
 │  │  │     │  └─ ···
 │  │  │     └─ Shape_i{0} [id E] 1
 │  │  │        └─ ···
 │  │  └─ 1 [id I]
 │  └─ Composite{...}.1 [id Y] 4
 │     ├─ [[1]] [id Z]
 │     ├─ ExpandDims{axes=[0, 1]} [id BA] 3
 │     │  └─ Shape_i{1} [id H] 0
 │     │     └─ ···
 │     └─ [[1.]] [id D]
 └─ Composite{((-i0) / i1)} [id BB] 18
    ├─ SpecifyShape [id M] 12
    │  └─ ···
    └─ Composite{...}.0 [id Y] 4
       └─ ···

Inner graphs:

Composite{((i0 + i1) / i2)} [id L]
 ← true_div [id BC] 'o0'
    ├─ add [id BD]
    │  ├─ i0 [id BE]
    │  └─ i1 [id BF]
    └─ i2 [id BG]

Composite{cast{float64}((-i0))} [id P]
 ← Cast{float64} [id BH] 'o0'
    └─ neg [id BI]
       └─ i0 [id BJ]

Composite{...} [id Y]
 ← sqrt [id BK] 'o0'
    └─ add [id BL]
       ├─ i0 [id BM]
       └─ i1 [id BN]
 ← add [id BO] 'o1'
    ├─ i2 [id BP]
    ├─ sqrt [id BK] 'o0'
    │  └─ ···
    └─ i1 [id BN]

Composite{((-i0) / i1)} [id BB]
 ← true_div [id BQ] 'o0'
    ├─ neg [id BR]
    │  └─ i0 [id BS]
    └─ i1 [id BT]

@ricardoV94 (Member)

Possibly related to #330

@ricardoV94 (Member) commented Jun 15, 2023

Nothing against the PR itself but note that these complicated graphs are likely to go away if you provide static shape for the inputs, which is the case with 99% of PyMC models.
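For instance (a sketch; recent PyTensor tensor constructors accept a static shape argument, with None for unknown lengths):

import pytensor.tensor as pt

# Declaring static shape information up front lets the shape graph
# fold to constants instead of the symbolic checks shown above.
x = pt.matrix("x", shape=(50, 10))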

@codecov-commenter commented Jun 15, 2023

Codecov Report

Merging #345 (53c26d1) into main (f4536c3) will decrease coverage by 0.01%.
The diff coverage is 95.23%.


@@            Coverage Diff             @@
##             main     #345      +/-   ##
==========================================
- Coverage   80.40%   80.39%   -0.01%     
==========================================
  Files         156      156              
  Lines       45416    45398      -18     
  Branches    11114    11106       -8     
==========================================
- Hits        36515    36498      -17     
+ Misses       6695     6694       -1     
  Partials     2206     2206              
Impacted Files Coverage Δ
pytensor/tensor/extra_ops.py 88.98% <95.23%> (-0.29%) ⬇️

... and 4 files with indirect coverage changes

@aseyboldt (Author)

@ricardoV94 True. I was mostly just hoping we could get to changes like those in #149 piece by piece; it does seem pretty tricky to do it all in one go...

@ricardoV94 (Member) left a review:

Looks good, left some nitpick suggestions.

# Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
@ricardoV94 (Member): Remove the "maybe"?

Suggested change
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]

)
else:
    bcast_dim = const_nt_shape_var

assert_op = Assert("Could not dynamically broadcast dimensions.")
@ricardoV94 (Member): Maybe create this Op once at the module level?
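i.e., something along these lines (a sketch; the name broadcast_assert_op is illustrative):

# At module level, constructed once and reused by broadcast_shape:
broadcast_assert_op = Assert("Could not dynamically broadcast dimensions.")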

@ricardoV94 (Member), on the same lines:

Suggested change
assert_op = Assert("Could not dynamically broadcast dimensions.")
assert_op = Assert("Could not broadcast dimensions. If a variable should broadcast use `specify_shape` to inform PyTensor.")

@aseyboldt (Author):

I made it even more specific:
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use specify_shape to "
"inform PyTensor of a known shape."

continue

# Add assert that all remaining shapes are equal
use_scalars = False
@ricardoV94 (Member): Let's remove the use_scalars block?

@ricardoV94 (Member) commented Jun 16, 2023

One thing we could consider in the future is a DynamicBroadcastOp that explicitly allows for it (and does whatever it takes to get the gradients to work out). I would wait on that because I suspect there is no real demand for that.

@ricardoV94 (Member) commented Jun 16, 2023

> The scalar expressions seem a bit more reasonable to me (no tensors and no allocations), but since most rewrites only work on Elemwise, this can prevent a lot of optimizations right now.

I think the only thing that's needed is to allow the FusionRewrite to work on 0d tensors (right now I think it requires that ndim > 1).

Fusing chains of 0d tensors inside a Composite would have the same effect as using scalars in the graph, with a small overhead from Elemwise (which also takes care of the otherwise needed ScalarFromTensor and TensorFromScalar)?

Once we have a 0d Elemwise composite it's also trivial to replace it by the scalar case if that's more efficient.

Created issue: #349

@aseyboldt (Author)

I don't think I understand what you mean by the FusionRewrite. If we used the code that never produces an Elemwise in the first place, then how would Elemwise fusion matter?

@ricardoV94 (Member) commented Jun 16, 2023

If you used the code with tensor.and_ (not scalar.and_ nor tensor.all), the FusionRewrite could then merge them all in a single Composite (once we allow it to optimize 0d Elemwises). If PyTensor wants to do other optimizations first it could also because we started with TensorVariables and not Scalars.
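A rough sketch of the distinction (variable names are mine):

import pytensor.tensor as pt

a = pt.lscalar("a")  # 0d int64 axis lengths
b = pt.lscalar("b")
c = pt.lscalar("c")

# Chained 0d Elemwise ops: once FusionRewrite handles 0d inputs
# (issue #349), it can fuse this chain into a single Composite.
cond = pt.and_(pt.eq(a, b), pt.eq(a, c))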

@aseyboldt (Author)

I applied the suggestions above.

> If you used the code with tensor.and_ (not scalar.and_ nor tensor.all), the FusionRewrite could then merge them all in a single Composite (once we allow it to optimize 0d Elemwises). If PyTensor wants to do other optimizations first it could also because we started with TensorVariables and not Scalars.

But if we work on tensors anyway, I don't think that would be better than the all(eq(...)) code, would it? The point was more that I would like not to have any tensors in here at all. Since Shape{i} unfortunately returns a tensor, this might be a bit pointless anyway, though...

@ricardoV94 (Member)

I'm gonna mark this as major so that we are forced to manually bump the dependency on pymc and run the test suite (in case it breaks something there)

@ricardoV94 ricardoV94 merged commit df4183d into pymc-devs:main Jun 17, 2023