Add OpFromGraph wrapper around alloc_diag #915
Conversation
I guess my little …
We probably want to stop wrapping the gradient in …
Is there a way to just forward all the props there for now? Or do you want me to look into the better solution?
I think we should use the base class. We don't want to apply AllocDiag to the gradient as well, do we?
OpFromGraph wasn't originally designed to be subclassed, so since we are doing something new these days it's not surprising that we need to rethink some aspects.
It also has some sketchy forwarding of "kwargs" already, but those are the compilation kwargs, which is also an odd design (we don't need to tackle that now).
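To make the gradient-wrapping concern concrete, here is a small sketch (mine, not the PR's code) of subclassing `OpFromGraph` and checking which class ends up wrapping the gradient graph; `MySquare` is a hypothetical stand-in for `AllocDiag`:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph


# Hypothetical subclass, standing in for AllocDiag in this discussion.
class MySquare(OpFromGraph):
    pass


x = pt.vector("x")
square_op = MySquare([x], [x ** 2])
y = square_op(x)

# Taking the gradient forces OpFromGraph to build a wrapper around the
# gradient graph; the question above is whether that wrapper should be the
# subclass (MySquare) or plain OpFromGraph.
g = pytensor.grad(y.sum(), x)
pytensor.dprint(g)
```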
This reverts commit f6f27ec.
I left some specialized comments that are only relevant after the big-picture question:
Why do we need to inline it specifically for JAX?
Codecov Report. Attention: Patch coverage is …
Additional details and impacted files:
@@ Coverage Diff @@
## main #915 +/- ##
==========================================
+ Coverage 81.04% 81.49% +0.45%
==========================================
Files 170 176 +6
Lines 46962 46925 -37
Branches 11507 11428 -79
==========================================
+ Hits 38059 38242 +183
+ Misses 6694 6500 -194
+ Partials 2209 2183 -26
We don't need to inline it only for JAX. I did it only for JAX because we must do it for JAX. Nothing stops us from doing it for all backends, and for all …
Why do we need to do it for JAX at all?
Otherwise we need a JAX dispatch for …
JAX should have a dispatch for OpFromGraph already, does it not?
We do not!!! That's silly; it should be as simple as pytensor/pytensor/link/numba/dispatch/basic.py, lines 429 to 452 (at e2f9cb8).
Ok, I'll implement that instead.
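For reference, a rough sketch of what such a JAX dispatch might look like, modeled on the numba one linked above. This is only an illustration under the assumption that `jax_funcify` can be applied recursively to the inner `FunctionGraph`; it is not the code that ultimately landed:

```python
from pytensor.compile.builders import OpFromGraph
from pytensor.link.jax.dispatch.basic import jax_funcify


@jax_funcify.register(OpFromGraph)
def jax_funcify_OpFromGraph(op, node=None, **kwargs):
    # Convert the inner graph to a JAX-compatible python function and call it
    # on the outer inputs, unpacking the single-output case for convenience.
    fgraph_fn = jax_funcify(op.fgraph, **kwargs)

    if len(op.fgraph.outputs) == 1:

        def opfromgraph(*inputs):
            return fgraph_fn(*inputs)[0]

    else:

        def opfromgraph(*inputs):
            return fgraph_fn(*inputs)

    return opfromgraph
```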
Numba is segfaulting in the same way I observed in the …
Which test is segfaulting?
Looks like case 3 of …
Can you check what the optimized graph looks like before and after this PR (you have to do this before the linking part starts and segfaults)?
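One possible way to inspect the rewritten graph without going through the linker (so the segfault never gets a chance to happen); the use of `rewrite_graph` and the `"fast_run"` rewrite tag here are my assumptions, not something stated in the thread:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
out = pt.diag(x)

# Apply the standard rewrites to the symbolic output only; no function is
# compiled, so no linker (and no segfault) is involved.
rewritten = rewrite_graph(out, include=("fast_run",))
pytensor.dprint(rewritten)
```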
Tiny suggestions, looks pretty good.
Follow up on the numba segfault: it appears to happen because …
AdvancedSetSubtensor [id A] 3
├─ Alloc [id B] 2
│ ├─ 0.0 [id C]
│ ├─ Shape_i{0} [id D] 0
│ │ └─ <Vector(float64, shape=(?,))> [id E]
│ └─ Shape_i{0} [id D] 0
│ └─ ···
├─ <Vector(float64, shape=(?,))> [id E]
├─ ARange{dtype='int64'} [id F] 1
│ ├─ 0 [id G]
│ ├─ Shape_i{0} [id D] 0
│ │ └─ ···
│ └─ 1 [id H]
└─ ARange{dtype='int64'} [id F] 1
 └─ ···
And here's the graph I get when …
So the OFG isn't rewritten away. Given what I was seeing in the … Or maybe not? When I make the …
AllocDiag{axis1=0, axis2=1} [id A] 0
└─ <Vector(float64, shape=(?,))> [id B]
Inner graphs:
AllocDiag{axis1=0, axis2=1} [id A]
← AdvancedSetSubtensor [id C]
├─ Alloc [id D]
│ ├─ 0.0 [id E]
│ ├─ Shape_i{0} [id F]
│ │ └─ *0-<Vector(float64, shape=(?,))> [id G]
│ └─ Shape_i{0} [id F]
│ └─ ···
├─ *0-<Vector(float64, shape=(?,))> [id G]
├─ ARange{dtype='int64'} [id H]
│ ├─ 0 [id I]
│ ├─ Shape_i{0} [id F]
│ │ └─ ···
│ └─ 1 [id J]
└─ ARange{dtype='int64'} [id H]
 └─ ···
That maximum and scalar_maximum are a bit suspect. Why is one a ScalarOp and not the other? Can you dprint with types and destroy/view maps? |
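Presumably something along these lines produced the printouts below; the flag names reflect my understanding of `pytensor.dprint`'s signature, so treat them as an assumption:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function([x], pt.diag(x))

pytensor.dprint(
    fn,
    print_type=True,         # show the type of each variable
    print_destroy_map=True,  # show d={...}: which inputs each node destroys
    print_view_map=True,     # show v={...}: which inputs each node views
)
```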
Without …:
AllocDiag{axis1=0, axis2=1} [id A] <Matrix(float64, shape=(?, ?))> 0
├─ <Vector(float64, shape=(?,))> [id B] <Vector(float64, shape=(?,))>
└─ 0 [id C] <Scalar(int8, shape=())>
Inner graphs:
AllocDiag{axis1=0, axis2=1} [id A]
← AdvancedSetSubtensor [id D] <Matrix(float64, shape=(?, ?))> d={0: [0]}
├─ Alloc [id E] <Matrix(float64, shape=(?, ?))>
│ ├─ 0.0 [id F] <Scalar(float64, shape=())>
│ ├─ Composite{(i1 + abs(i0))} [id G] <Scalar(int64, shape=())> d={0: [1]}
│ │ ├─ *1-<Scalar(int8, shape=())> [id H] <Scalar(int8, shape=())>
│ │ └─ Shape_i{0} [id I] <Scalar(int64, shape=())>
│ │ └─ *0-<Vector(float64, shape=(?,))> [id J] <Vector(float64, shape=(?,))>
│ └─ Composite{(i1 + abs(i0))} [id G] <Scalar(int64, shape=())> d={0: [1]}
│ └─ ···
├─ *0-<Vector(float64, shape=(?,))> [id J] <Vector(float64, shape=(?,))>
├─ Add [id K] <Vector(int64, shape=(?,))>
│ ├─ ARange{dtype='int64'} [id L] <Vector(int64, shape=(?,))>
│ │ ├─ 0 [id M] <Scalar(int8, shape=())>
│ │ ├─ Shape_i{0} [id I] <Scalar(int64, shape=())>
│ │ │ └─ ···
│ │ └─ 1 [id N] <Scalar(int8, shape=())>
│ └─ Composite{maximum(0, (-i0))} [id O] <Vector(int8, shape=(1,))>
│ └─ ExpandDims{axis=0} [id P] <Vector(int8, shape=(1,))> v={0: [0]}
│ └─ *1-<Scalar(int8, shape=())> [id H] <Scalar(int8, shape=())>
└─ Add [id Q] <Vector(int64, shape=(?,))> d={0: [0]}
├─ ARange{dtype='int64'} [id L] <Vector(int64, shape=(?,))>
│ └─ ···
└─ Scalarmaximum [id R] <Vector(int8, shape=(1,))> d={0: [1]}
├─ [0] [id S] <Vector(int8, shape=(1,))>
└─ ExpandDims{axis=0} [id P] <Vector(int8, shape=(1,))> v={0: [0]}
└─ ···
Composite{(i1 + abs(i0))} [id G] d={0: [1]}
← add [id T] <int64> 'o0'
├─ i1 [id U] <int64>
└─ Abs [id V] <int8>
└─ i0 [id W] <int8>
Composite{maximum(0, (-i0))} [id O]
← maximum [id X] <int8> 'o0'
├─ 0 [id Y] <int8>
└─ neg [id Z] <int8>
 └─ i0 [id BA] <int8>
With …:
AdvancedSetSubtensor [id A] <Matrix(float64, shape=(?, ?))> d={0: [0]} 3
├─ Alloc [id B] <Matrix(float64, shape=(?, ?))> 2
│ ├─ 0.0 [id C] <Scalar(float64, shape=())>
│ ├─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
│ │ └─ <Vector(float64, shape=(?,))> [id E] <Vector(float64, shape=(?,))>
│ └─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
│ └─ ···
├─ <Vector(float64, shape=(?,))> [id E] <Vector(float64, shape=(?,))>
├─ ARange{dtype='int64'} [id F] <Vector(int64, shape=(?,))> 1
│ ├─ 0 [id G] <Scalar(int8, shape=())>
│ ├─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
│ │ └─ ···
│ └─ 1 [id H] <Scalar(int8, shape=())>
└─ ARange{dtype='int64'} [id F] <Vector(int64, shape=(?,))> 1
 └─ ···
That ScalarMaximum with shape=(1,) seems to be trying to write onto what is ultimately a shape=() array. Could that be a problem? I wonder if it's a problem in the C backend. Would need something like …
Or can you exclude inplace when optimizing the inner OFG to see if it's related to the problem?
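A sketch of how those two checks could be set up; the mode names and the `"inplace"` exclusion tag are assumptions about PyTensor's mode API rather than quotes from this thread, and note the follow-up question below about whether the exclusion also needs to apply inside the OFG:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

x = pt.vector("x")
out = pt.diag(x)

# 1) Same backend, but with the inplace rewrites excluded, to see whether
#    the segfault is tied to the inplace ScalarMaximum above.
numba_no_inplace = get_mode("NUMBA").excluding("inplace")
f_numba = pytensor.function([x], out, mode=numba_no_inplace)

# 2) Same rewrites, but the C backend instead of numba, to see whether the
#    problem is numba-specific.
f_c = pytensor.function([x], out, mode=get_mode("FAST_RUN"))
```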
I'm curious why the two …
I'm not sure they are; maybe it's one of those Elemwises with a weird custom name. The negated one gets fused with the negation.
You can also check what happens with a switch instead of maximum. (For concreteness, see the sketch below.)
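The kind of substitution being suggested, based on the `Composite{maximum(0, (-i0))}` node in the dprint above; the variable name `offset` is illustrative, not the one used inside `alloc_diag`:

```python
import pytensor.tensor as pt

offset = pt.iscalar("offset")

# Current formulation: clamp the negated offset at zero.
via_maximum = pt.maximum(0, -offset)

# Suggested alternative: an explicit branch, which lowers to Switch instead
# of (Scalar)Maximum.
via_switch = pt.switch(offset < 0, -offset, 0)
```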
No problem in the C backend using … Switch also segfaults, as does clip.
Honestly my inclination now is to just switch … Edit: Or just change the test to use …
Are you excluding inplace "inside" the OFG? That's what I meant for testing inplace. The mode example I gave was to see if the same problem would happen in the C backend; we can't just test with FAST_COMPILE because that skips C. We can make it static, no need to fix it in this PR. It will come back to us sooner or later.
* Add `OpFromGraph` wrapper around `alloc_diag`
* Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
* Set `inline = False`
* Add rewrite to inline all `OpFromGraph` `Op`s
* Add `is_zero_offset` helper to `Eye`
* Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle`
* Seed `test_local_lift_through_linalg` test
Description
The relatively complex graph generated by `pt.diag(x)` where `x.ndim == 1` is causing problems for rewrites like #860 and #782. We were using a pattern rewriter, but this is proving to be quite brittle. It would be much easier if there were a single `Op` we could just look for. This PR introduces such an Op, `AllocDiag2`, as an `OpFromGraph` wrapper around the existing `alloc_diag` function.
Related Issue
Checklist
Type of change
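For context, a minimal usage sketch of what the PR enables; the import path of `AllocDiag` (and that `AllocDiag2` ends up renamed to `AllocDiag`, per the commit list above) are assumptions here:

```python
import pytensor.tensor as pt
# Assumed location of the new OpFromGraph subclass after the rename.
from pytensor.tensor.basic import AllocDiag

x = pt.vector("x")
y = pt.diag(x)

# pt.diag on a vector should now produce a single AllocDiag node
# (an OpFromGraph), which rewrites can match directly.
print(type(y.owner.op))
assert isinstance(y.owner.op, AllocDiag)
```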