Implement new Loop and Scan operators #191
base: main
Conversation
assert input_state.type == output_state.type

class Loop(Op):
TODO: Add mixin HasInnerGraph
so that we can see the inner graph in debug_print
Force-pushed from 76a9b4c to f2a2c03
Wouldn't a fill loop look something like this?
(and very much need inplace rewrites for good performance...)
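For example, with the existing scan a fill loop is essentially the following (a sketch only; the step function, buffer handling, and names are illustrative and not this PR's Loop API):

import numpy as np
import pytensor
import pytensor.tensor as pt

# Sketch of a fill loop using the existing scan: the partially filled buffer is
# carried as loop state and one row is written per step. set_subtensor returns a
# new tensor, so without an inplace rewrite every iteration copies the whole buffer.
def step(i, buffer, values):
    return i + 1, pt.set_subtensor(buffer[i], values[i])

n = 5
values = pt.as_tensor_variable(np.random.normal(size=(n, 3)))
outputs, updates = pytensor.scan(
    step,
    outputs_info=[pt.constant(0, dtype="int64"), pt.zeros((n, 3))],
    non_sequences=[values],
    n_steps=n,
)
filled = outputs[1][-1]  # the fully filled (n, 3) buffer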
Good question...

I think one rewrite that gets easier with the if-else-do-while approach would be loop-invariant code motion. Let's say we have a loop whose body contains a computation that does not depend on the loop state; we could move that computation out of the loop (see the sketch below).
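A minimal sketch of that idea (variable and function names are made up for illustration; this is not code from the PR):

import pytensor.tensor as pt

A = pt.matrix("A")

# Before: the invariant term is recomputed in every iteration of the loop body.
def step(x, A):
    c = pt.exp(A).sum()  # does not depend on the loop state x
    return x + c

# After the rewrite: compute it once outside the loop and feed it in as a
# constant input.
c = pt.exp(A).sum()
def step_hoisted(x, c):
    return x + c

# Caveat raised below: with a while loop the body may run zero times, so the
# hoisted expression should only be evaluated if the loop is actually entered
# (e.g. an if-else around a do-while); otherwise asserts, errors, or costly work
# could trigger even though the original graph would never have executed them.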
Well, I guess we really need those :-)
Why can't we move it even if it's empty? Sum works fine. Are you worried about Ops that we know will fail with empty inputs?

About the filling Ops: yeah, I don't see it as a problem anymore. It just felt awkward to create the dummy input when translating from scan to loop. I am okay with it now.
That would change the behavior. If we move it out and don't prevent it from being executed, things could fail, for instance if there's an assert somewhere, or some other error happens during its evaluation. Also, it could be potentially very costly (let's say "solve an ODE").

(Somehow I accidentally edited your comment instead of writing a new one, no clue how, but it's fixed now.)
In my last commit, sequences are demoted from special citizens to just another constant input.

I have reverted converting the constant inputs to dummies before calling the user function, which allows the example in the jacobian documentation to work, including the one that didn't work before (because both are now equivalent under the hood :)) https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html#computing-the-jacobian

I reverted too much, though, and I still need to pass dummy inputs as the state variables, since it doesn't make sense for the user function to introspect the graph beyond the initial state (it's only valid for the initial state).
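For reference, the Jacobian example on that page is roughly:

import pytensor
import pytensor.tensor as pt

# One row of the Jacobian is computed per scan step.
x = pt.dvector("x")
y = x ** 2
J, updates = pytensor.scan(
    lambda i, y, x: pt.grad(y[i], x),
    sequences=pt.arange(y.shape[0]),
    non_sequences=[y, x],
)
f = pytensor.function([x], J, updates=updates)
f([4.0, 4.0])  # -> [[8., 0.], [0., 8.]]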
Force-pushed from 7bcd42c to 6c953b3
return last_states[1:], traces[1:]

def map(
What about subclassing Scan into

Map(Scan)
Reduce(Scan)
Filter(Scan)

It would be easier to dispatch to optimized implementations (see the sketch below).
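A rough sketch of what that could look like (class names and the dispatch comment are hypothetical, not part of this PR):

class Map(Scan):
    """Scan with no carried state: each output depends only on the current input."""

class Reduce(Scan):
    """Scan where only the final state is needed, so traces can be dropped."""

class Filter(Scan):
    """Scan whose kept trace entries are data-dependent."""

# A backend could then register specialized lowerings per subclass, e.g. dispatch
# a Map to jax.vmap / jax.lax.map instead of jax.lax.scan.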
We can do that later, not convinced we need that yet
if init_state is None:
    # next_state may reference idx. We replace that by the initial value,
    # so that the shape of the dummy init state does not depend on it.
    [next_state] = clone_replace(
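A stand-alone sketch of what that replacement does (variable names are hypothetical; assumes pytensor.clone_replace):

import pytensor.tensor as pt
from pytensor import clone_replace

# Substitute the symbolic loop index by its initial value, so that the shape of
# the dummy initial state does not depend on the index.
symbolic_idx = pt.scalar("idx", dtype="int64")
idx = pt.constant(0, dtype="int64")
next_state = pt.arange(symbolic_idx + 1)  # shape depends on the index

[next_state] = clone_replace([next_state], replace={symbolic_idx: idx})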
Why not graph_replace or using memo for FunctionGraph(memo={symbolic_idx: idx}) (here)?
Why is that better?
Added a simple JAX dispatcher, works in the few examples I tried.
# explicitly triggers the optimization of the inner graphs of Scan?
update_fg = op.update_fg.clone()
rewriter = get_mode("JAX").optimizer
rewriter(update_fg)
This gives an annoying Supervisor Feature missing warning... gotta clean that up
print(max_iters)
states, traces = jax.lax.scan(
    scan_fn, init=list(states), xs=None, length=max_iters
Todo: Check we are not missing performance by not having explicit sequences.
Todo: When there are multiple sequences, PyTensor defines n_steps as the shortest sequence. JAX should be able to handle this, but if not we could consider not allowing sequences/n_steps with different lengths in the PyTensor scan. Then we could pass a single shape as n_steps after asserting they are all the same?
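A self-contained illustration of the xs=None pattern used above:

import jax
import jax.numpy as jnp

# With no explicit sequences, lax.scan is driven purely by `length`, with xs=None.
def scan_fn(carry, _):
    x, s = carry
    x = x + 1
    s = s + x
    return (x, s), s  # (new carry, value appended to the trace)

init = (jnp.array(0), jnp.array(0))
(final_x, final_s), trace = jax.lax.scan(scan_fn, init=init, xs=None, length=5)
# trace.shape == (5,) and final_s == trace[-1]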
I just found out about TypedLists in PyTensor. That should allow us to trace any type of Variable, including RandomTypes 🤯 Pushed a couple of commits that rely on this.
Force-pushed from 5f15c5e to 32b4fb4
Co-authored-by: Adrian Seyboldt <adrian.seyboldt@gmail.com>
This was not possible prior to the use of TypedListType for non-TensorVariable sequences, as it would otherwise not be possible to represent indexing of the last sequence state, which is needed, e.g., for shared random generator updates.
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #191      +/-   ##
==========================================
+ Coverage   80.03%   80.09%   +0.06%
==========================================
  Files         170      173       +3
  Lines       45086    45435     +349
  Branches     9603     9694      +91
==========================================
+ Hits        36085    36392     +307
- Misses       6789     6818      +29
- Partials     2212     2225      +13
This Discourse thread is a great reminder of several Scan design issues that are fixed here: https://discourse.pymc.io/t/hitting-a-weird-error-to-do-with-rngs-in-scan-in-a-custom-function-inside-a-potential/13151/15
Related to #189
This PR implements a new low level Loop Op which can be easily transpiled to Numba (the Python perform method takes 9 lines; yay to not having to support C in the future).

It also implements a new higher level Scan Op which returns as outputs the last states + intermediate states of a looping operation. This Op cannot be directly evaluated, and must be rewritten as a Loop Op in the Python/Numba backends. For the JAX backend it's probably fine to transpile directly from this representation into a lax.scan, as the signatures are pretty much identical. That was not done in this PR.

The reason for the two types of outputs is that they are useful in different contexts. Final states are sometimes all one needs, whereas intermediate states are generally needed for back propagation (not implemented yet). This allows us to choose which one (or both) of the outputs we want during compilation, without having to do complicated graph analysis.

The existing save_mem_new_scan is used to convert a general scan into a loop that only returns the last computed state. It's... pretty complicated (although it also covers cases where more than one but fewer than all steps are requested; OTOH it can't handle while loops, #178):

pytensor/pytensor/scan/rewriting.py, line 1119 in 8ad3317

Taking that as a reference, I would say the new conversion rewrite from Scan to Loop is much, much simpler. Most of it is boilerplate code for defining the right trace inputs and the new FunctionGraph.

Both Ops expect a FunctionGraph as input. This should probably be created by a user-facing helper that accepts a callable, like scan does now. That was not done at first, as I wanted to discuss the general design; it is done now.

Design issues

1. The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations, such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan, where the outputs_info must explicitly be None in these cases.
2. Scan and Loop can now take random types as inputs (scan can't return them as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables (see the sketch after this list). However, it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to __call__, so that it doesn't hide the next rng state output?
3. Do we want to be able to represent empty Loops / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?
4. What do we want to do in terms of inplacing optimizations?
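A minimal sketch of the old behaviour that point 2 contrasts with, using the existing scan and RandomStream (illustrative only):

import pytensor
from pytensor.tensor.random.utils import RandomStream

# Old-Scan style seeding: the RNG lives in a shared variable and is advanced via
# default updates collected by scan, so the next rng state never shows up as an
# explicit output of the graph.
srng = RandomStream(seed=123)
values, updates = pytensor.scan(lambda: srng.normal(), n_steps=5)
draws = pytensor.function([], values, updates=updates)()  # five normal draws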
TODO

If people are on board with the approach:
- ... (mode, truncate_gradient, reverse, and so on)
- A rewrite that replaces trace[-1]
by the first set of outputs (final state). That way we can keep the old API, while retaining the benefit of doing while Scans without tracing when it's not needed.