
Add OpFromGraph wrapper around alloc_diag #915

Merged Jul 18, 2024 (33 commits)

Conversation

jessegrabowski
Member

Description

The relatively complex graph generated by pt.diag(x) where x.ndim == 1 is causing problems for rewrites like #860 and #782. We were using a pattern rewriter, but this is proving to be quite brittle. It would be much easier if there were a single Op we could just look for. This PR introduces such an Op, AllocDiag2, as an OpFromGraph wrapper around the existing alloc_diag function.
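For reference, the computation that `alloc_diag` wraps (allocate a zero matrix, then scatter the vector onto the offset diagonal) can be sketched in plain NumPy. This is an illustrative stand-in, not the PyTensor implementation, and `alloc_diag_np` is a hypothetical name:

```python
import numpy as np

def alloc_diag_np(x, offset=0):
    """Sketch of what alloc_diag computes for a 1d input: allocate a
    zero matrix and write x onto the diagonal selected by offset."""
    n = x.shape[0] + abs(offset)
    out = np.zeros((n, n), dtype=x.dtype)
    idx = np.arange(x.shape[0])
    rows = idx + max(0, -offset)  # shift rows down for sub-diagonals
    cols = idx + max(0, offset)   # shift cols right for super-diagonals
    out[rows, cols] = x
    return out
```

The `maximum(0, offset)` / `maximum(0, -offset)` index shifts here are the same terms that show up as `Composite{maximum(0, (-i0))}` nodes in the graph dumps later in this thread.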

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jessegrabowski
Member Author

I guess my little __props__ thing doesn't work the way I want it to, because these arguments don't get propagated into make_node when the OpFromGraph needs to be rebuilt (I guess that's why the scan tests are failing). I think scan rewrites happen first, so I can't move up the rewrite that eliminates the OpFromGraph. I think the approach here is fundamentally flawed, so I need some guidance on how to proceed.

@ricardoV94
Member

ricardoV94 commented Jul 10, 2024

We probably want to stop wrapping the gradient in type(self). Since we are using subclasses to target specialized forms, we don't want to mistakenly target the gradients. Perhaps the gradient should always use the base OpFromGraph for wrapping? And maybe wrapping the gradient shouldn't be the default anyway.

>       lop_op = type(self)(
            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
            outputs=connected_input_grads,
            inline=self.is_inline,
            name=(None if self.name is None else f"{self.name}_LOp"),
            # TODO: We can be eager here and exclude unused inputs in the OFG
            on_unused_input="ignore",
        )
E       TypeError: AllocDiag2.__init__() missing 3 required keyword-only arguments: 'offset', 'axis1', and 'axis2'
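The failure can be reproduced outside PyTensor: a base-class method that rebuilds an instance via type(self) breaks for any subclass whose __init__ adds required keyword-only arguments. A minimal self-contained sketch (all class names here are hypothetical stand-ins):

```python
class OpFromGraphLike:
    """Stand-in for OpFromGraph (illustration only, not the real API)."""

    def __init__(self, *, name=None):
        self.name = name

    def grad_op_subclass(self):
        # The failing pattern: rebuilding via type(self) assumes every
        # subclass accepts the same constructor arguments as the base.
        return type(self)(name=f"{self.name}_LOp")

    def grad_op_base(self):
        # The suggested fix: always wrap the gradient in the base class.
        return OpFromGraphLike(name=f"{self.name}_LOp")


class AllocDiagLike(OpFromGraphLike):
    """Stand-in for AllocDiag2: adds required keyword-only props."""

    def __init__(self, *, offset, axis1, axis2, name=None):
        super().__init__(name=name)
        self.offset, self.axis1, self.axis2 = offset, axis1, axis2


op = AllocDiagLike(offset=0, axis1=0, axis2=1, name="AllocDiag")
try:
    op.grad_op_subclass()
except TypeError as e:
    print(e)  # missing 3 required keyword-only arguments, as in the traceback
grad = op.grad_op_base()  # works: plain base-class wrapper
```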

@jessegrabowski
Member Author

Is there a way to just forward all the props there for now? Or do you want me to look into the better solution?

@ricardoV94
Member

Is there a way to just forward all the props there for now? Or do you want me to look into the better solution

I think we should use the base class. We don't want to call the gradient AllocDiag as well, do we?

@ricardoV94
Member

OpFromGraph wasn't originally designed to be subclassed, so we are doing something new these days; it's not surprising we need to rethink some aspects.

@ricardoV94
Member

It also has some sketchy forwarding of "kwargs" already, but those are the compilation kwargs, which is also an odd design (we don't need to tackle that now)


@ricardoV94 ricardoV94 left a comment


I left some specific comments that are only relevant after the big-picture question:

Why do we need to inline it specifically for JAX?

Inline review threads (outdated, resolved) on pytensor/link/jax/dispatch/tensor_basic.py

codecov bot commented Jul 12, 2024

Codecov Report

Attention: Patch coverage is 97.22222% with 2 lines in your changes missing coverage. Please review.

Project coverage is 81.49%. Comparing base (05d376f) to head (56a3ffe).
Report is 152 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/linalg.py 87.50% 1 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #915      +/-   ##
==========================================
+ Coverage   81.04%   81.49%   +0.45%     
==========================================
  Files         170      176       +6     
  Lines       46962    46925      -37     
  Branches    11507    11428      -79     
==========================================
+ Hits        38059    38242     +183     
+ Misses       6694     6500     -194     
+ Partials     2209     2183      -26     
Files with missing lines Coverage Δ
pytensor/compile/builders.py 88.66% <100.00%> (+0.24%) ⬆️
pytensor/link/jax/dispatch/basic.py 80.76% <100.00%> (+4.20%) ⬆️
pytensor/tensor/basic.py 91.38% <100.00%> (+2.97%) ⬆️
pytensor/tensor/elemwise.py 88.38% <100.00%> (-0.04%) ⬇️
pytensor/tensor/rewriting/ofg.py 100.00% <100.00%> (ø)
pytensor/tensor/rewriting/linalg.py 90.78% <87.50%> (+1.70%) ⬆️

... and 79 files with indirect coverage changes

@jessegrabowski
Member Author

jessegrabowski commented Jul 12, 2024

We don't need to only inline it for JAX. I did it only for JAX because we must do it for JAX. Nothing stops us from doing it for all backends, and for all OpFromGraphs. Is there a benefit to doing so?

@ricardoV94
Member

Why do we need to do it for JAX at all?

@jessegrabowski
Member Author

Otherwise we need a jax dispatch for AllocDiag specifically

@ricardoV94
Member

JAX should have a dispatch for OpFromGraph already, does it not?

@ricardoV94
Member

We do not!!! That's silly; it should be as simple as:

@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
    _ = kwargs.pop("storage_map", None)

    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
    # The C-code defers it to the make_thunk phase
    NUMBA.optimizer(op.fgraph)
    fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

    if len(op.fgraph.outputs) == 1:

        @numba_njit
        def opfromgraph(*inputs):
            return fgraph_fn(*inputs)[0]

    else:

        @numba_njit
        def opfromgraph(*inputs):
            return fgraph_fn(*inputs)

    return opfromgraph
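The single-output branch in that handler exists because the compiled inner-fgraph function always returns a tuple of outputs, so a single-output OpFromGraph must be unwrapped to index 0 to match the calling convention of other thunks. Stripped of the numba machinery, the branching amounts to (make_ofg_thunk is a hypothetical name for illustration):

```python
def make_ofg_thunk(inner_fn, n_outputs):
    """Plain-Python sketch of the output-unwrapping logic: the inner
    function returns a tuple, single-output wrappers unwrap it."""
    if n_outputs == 1:
        def thunk(*inputs):
            return inner_fn(*inputs)[0]
    else:
        def thunk(*inputs):
            return inner_fn(*inputs)
    return thunk


# Example: a fake "inner function" that increments each input.
inner = lambda *xs: tuple(x + 1 for x in xs)
print(make_ofg_thunk(inner, 1)(41))    # -> 42
print(make_ofg_thunk(inner, 2)(1, 2))  # -> (2, 3)
```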

@jessegrabowski
Member Author

Ok I'll implement that instead

@jessegrabowski
Member Author

Numba is segfaulting in the same way I observed in the pad rewrite. I think it's related to the symbolic offset, because this test was passing before. Maybe there's an illegal indexer in the generated code?

Inline review threads (resolved):
  • pytensor/compile/mode.py
  • pytensor/tensor/rewriting/linalg.py (4 threads)
  • pytensor/tensor/rewriting/ofg.py
  • tests/tensor/rewriting/test_linalg.py
  • tests/tensor/rewriting/test_ofg.py
@ricardoV94
Member

Numba is segfaulting in the same way I observed in the pad rewrite. I think it's related to the symbolic offset, because this test was passing before. Maybe there's an illegal indexer in the generated code?

Which test is segfaulting?

@jessegrabowski
Member Author

Looks like case 3 of test_tensor_basic.py::test_ExtractDiag

@ricardoV94
Member

ricardoV94 commented Jul 17, 2024

Looks like case 3 of test_tensor_basic.py::test_ExtractDiag

Can you check how the optimized graph looks before and after this PR (you have to do this before the linking part starts and segfaults)?

Inline review threads (resolved) on pytensor/tensor/rewriting/linalg.py and pytensor/tensor/rewriting/ofg.py

@ricardoV94 ricardoV94 left a comment


Tiny suggestions, looks pretty good.

Inline review threads (resolved) on pytensor/tensor/elemwise.py
@jessegrabowski
Member Author

jessegrabowski commented Jul 18, 2024

Follow up on the numba segfault: it appears to happen because fast_run optimizations are omitted from the numba_mode used by compare_numba_and_py. Here's the graph I get when mode="NUMBA":

AdvancedSetSubtensor [id A] 3
 ├─ Alloc [id B] 2
 │  ├─ 0.0 [id C]
 │  ├─ Shape_i{0} [id D] 0
 │  │  └─ <Vector(float64, shape=(?,))> [id E]
 │  └─ Shape_i{0} [id D] 0
 │     └─ ···
 ├─ <Vector(float64, shape=(?,))> [id E]
 ├─ ARange{dtype='int64'} [id F] 1
 │  ├─ 0 [id G]
 │  ├─ Shape_i{0} [id D] 0
 │  │  └─ ···
 │  └─ 1 [id H]
 └─ ARange{dtype='int64'} [id F] 1
    └─ ···

And here's the graph I get when mode = numba_mode:

AllocDiag{axis1=0, axis2=1} [id A]
 ← AdvancedSetSubtensor [id D]
    ├─ Alloc [id E]
    │  ├─ 0.0 [id F]
    │  ├─ Composite{(i1 + abs(i0))} [id G]
    │  │  ├─ *1-<Scalar(int8, shape=())> [id H]
    │  │  └─ Shape_i{0} [id I]
    │  │     └─ *0-<Vector(float64, shape=(?,))> [id J]
    │  └─ Composite{(i1 + abs(i0))} [id G]
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id J]
    ├─ Add [id K]
    │  ├─ ARange{dtype='int64'} [id L]
    │  │  ├─ 0 [id M]
    │  │  ├─ Shape_i{0} [id I]
    │  │  │  └─ ···
    │  │  └─ 1 [id N]
    │  └─ Composite{maximum(0, (-i0))} [id O]
    │     └─ ExpandDims{axis=0} [id P]
    │        └─ *1-<Scalar(int8, shape=())> [id H]
    └─ Add [id Q]
       ├─ ARange{dtype='int64'} [id L]
       │  └─ ···
       └─ Scalarmaximum [id R]
          ├─ [0] [id S]
          └─ ExpandDims{axis=0} [id P]
             └─ ···

Composite{(i1 + abs(i0))} [id G]
 ← add [id T] 'o0'
    ├─ i1 [id U]
    └─ Abs [id V]
       └─ i0 [id W]

Composite{maximum(0, (-i0))} [id O]
 ← maximum [id X] 'o0'
    ├─ 0 [id Y]
    └─ neg [id Z]
       └─ i0 [id BA]

So the OFG isn't rewritten away. Given what I was seeing in the pad PR, my guess is that this has something to do with how the inner graphs are being handled when an OFG is jitted.

Or maybe not? When I make the offset arg non-symbolic again, I get this graph (which does not segfault):

AllocDiag{axis1=0, axis2=1} [id A] 0
 └─ <Vector(float64, shape=(?,))> [id B]

Inner graphs:

AllocDiag{axis1=0, axis2=1} [id A]
 ← AdvancedSetSubtensor [id C]
    ├─ Alloc [id D]
    │  ├─ 0.0 [id E]
    │  ├─ Shape_i{0} [id F]
    │  │  └─ *0-<Vector(float64, shape=(?,))> [id G]
    │  └─ Shape_i{0} [id F]
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id G]
    ├─ ARange{dtype='int64'} [id H]
    │  ├─ 0 [id I]
    │  ├─ Shape_i{0} [id F]
    │  │  └─ ···
    │  └─ 1 [id J]
    └─ ARange{dtype='int64'} [id H]
       └─ ···

@ricardoV94
Member

ricardoV94 commented Jul 18, 2024

That maximum and scalar_maximum are a bit suspect. Why is one a ScalarOp and not the other?

Can you dprint with types and destroy/view maps?

@jessegrabowski
Member Author

without FAST_RUN :

AllocDiag{axis1=0, axis2=1} [id A] <Matrix(float64, shape=(?, ?))> 0
 ├─ <Vector(float64, shape=(?,))> [id B] <Vector(float64, shape=(?,))>
 └─ 0 [id C] <Scalar(int8, shape=())>

Inner graphs:

AllocDiag{axis1=0, axis2=1} [id A]
 ← AdvancedSetSubtensor [id D] <Matrix(float64, shape=(?, ?))> d={0: [0]}
    ├─ Alloc [id E] <Matrix(float64, shape=(?, ?))>
    │  ├─ 0.0 [id F] <Scalar(float64, shape=())>
    │  ├─ Composite{(i1 + abs(i0))} [id G] <Scalar(int64, shape=())> d={0: [1]}
    │  │  ├─ *1-<Scalar(int8, shape=())> [id H] <Scalar(int8, shape=())>
    │  │  └─ Shape_i{0} [id I] <Scalar(int64, shape=())>
    │  │     └─ *0-<Vector(float64, shape=(?,))> [id J] <Vector(float64, shape=(?,))>
    │  └─ Composite{(i1 + abs(i0))} [id G] <Scalar(int64, shape=())> d={0: [1]}
    │     └─ ···
    ├─ *0-<Vector(float64, shape=(?,))> [id J] <Vector(float64, shape=(?,))>
    ├─ Add [id K] <Vector(int64, shape=(?,))>
    │  ├─ ARange{dtype='int64'} [id L] <Vector(int64, shape=(?,))>
    │  │  ├─ 0 [id M] <Scalar(int8, shape=())>
    │  │  ├─ Shape_i{0} [id I] <Scalar(int64, shape=())>
    │  │  │  └─ ···
    │  │  └─ 1 [id N] <Scalar(int8, shape=())>
    │  └─ Composite{maximum(0, (-i0))} [id O] <Vector(int8, shape=(1,))>
    │     └─ ExpandDims{axis=0} [id P] <Vector(int8, shape=(1,))> v={0: [0]}
    │        └─ *1-<Scalar(int8, shape=())> [id H] <Scalar(int8, shape=())>
    └─ Add [id Q] <Vector(int64, shape=(?,))> d={0: [0]}
       ├─ ARange{dtype='int64'} [id L] <Vector(int64, shape=(?,))>
       │  └─ ···
       └─ Scalarmaximum [id R] <Vector(int8, shape=(1,))> d={0: [1]}
          ├─ [0] [id S] <Vector(int8, shape=(1,))>
          └─ ExpandDims{axis=0} [id P] <Vector(int8, shape=(1,))> v={0: [0]}
             └─ ···

Composite{(i1 + abs(i0))} [id G] d={0: [1]}
 ← add [id T] <int64> 'o0'
    ├─ i1 [id U] <int64>
    └─ Abs [id V] <int8>
       └─ i0 [id W] <int8>

Composite{maximum(0, (-i0))} [id O]
 ← maximum [id X] <int8> 'o0'
    ├─ 0 [id Y] <int8>
    └─ neg [id Z] <int8>
       └─ i0 [id BA] <int8>

with FAST_RUN:

AdvancedSetSubtensor [id A] <Matrix(float64, shape=(?, ?))> d={0: [0]} 3
 ├─ Alloc [id B] <Matrix(float64, shape=(?, ?))> 2
 │  ├─ 0.0 [id C] <Scalar(float64, shape=())>
 │  ├─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
 │  │  └─ <Vector(float64, shape=(?,))> [id E] <Vector(float64, shape=(?,))>
 │  └─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
 │     └─ ···
 ├─ <Vector(float64, shape=(?,))> [id E] <Vector(float64, shape=(?,))>
 ├─ ARange{dtype='int64'} [id F] <Vector(int64, shape=(?,))> 1
 │  ├─ 0 [id G] <Scalar(int8, shape=())>
 │  ├─ Shape_i{0} [id D] <Scalar(int64, shape=())> 0
 │  │  └─ ···
 │  └─ 1 [id H] <Scalar(int8, shape=())>
 └─ ARange{dtype='int64'} [id F] <Vector(int64, shape=(?,))> 1
    └─ ···

@ricardoV94
Member

ricardoV94 commented Jul 18, 2024

That ScalarMaximum shape=(1,) seems to be trying to write onto what's ultimately a shape=() array. Could that be a problem?

Wonder if it's a problem in the C backend. Would need something like Mode(linker=get_mode("FAST_RUN").linker, optimizer=get_mode("FAST_COMPILE").optimizer).

@ricardoV94
Member

Or can you exclude inplace when optimizing the inner OFG to see if it's related to the problem?

@jessegrabowski
Member Author

I'm curious why maximum(0, offset) and maximum(0, -offset) are being treated so differently.

@ricardoV94
Member

ricardoV94 commented Jul 18, 2024

I'm curious why maximum(0, offset) and maximum(0, -offset) are being treated so differently.

I'm not sure they are; maybe it's one of those Elemwise Ops with a weird custom name. The negated one gets fused with the negation.

@ricardoV94
Member

You can also check what happens with a switch instead of maximum

@jessegrabowski
Member Author

No problem in the C backend using Mode(linker=get_mode("FAST_RUN").linker, optimizer=get_mode("FAST_COMPILE").optimizer).

Doing numba_mode.excluding('inplace') still segfaults.

Switch also segfaults, as does clip.

@jessegrabowski
Member Author

jessegrabowski commented Jul 18, 2024

Honestly my inclination now is to just switch offset back to a constant, and add a TODO to make it symbolic. I want to get this merged to unblock @tanish1729 's GSOC work.

Edit: Or just change the test to use mode="NUMBA", which works fine. It's only when you use numba_mode = Mode(NumbaLinker(), opts.including("numba", "local_useless_unbatched_blockwise")) that there's a problem. The default numba mode also includes "fast_run".

@ricardoV94
Member

ricardoV94 commented Jul 18, 2024

Are you excluding inplace "inside" the OFG? That's what I meant for testing inplace

The mode example I gave was to see if the same problem would happen in the C backend. Can't just test with FAST_COMPILE because that skips C

We can make it static; no need to fix it in this PR. It will come back to us sooner or later.

@jessegrabowski jessegrabowski merged commit cac9feb into pymc-devs:main Jul 18, 2024
58 of 59 checks passed
Ch0ronomato pushed a commit to Ch0ronomato/pytensor that referenced this pull request Aug 15, 2024
* Add `OpFromGraph` wrapper around `alloc_diag`

* Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`

* Set `inline = False`

* Add rewrite to inline all `OpFromGraph` `Op`s

* Add `is_zero_offset` helper to `Eye`

* Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle`

* Seed `test_local_lift_through_linalg` test
@jessegrabowski jessegrabowski deleted the alloc-diag-ofg branch October 8, 2024 12:46