-
What do you think about the implicit semantics we have now for "truncated" outputs? Right now if the input trace has This requires some overhead with rolling outputs and tracking the pointer. I think it may make sense (especially with taps), but I would like to do those at compilation time (if LLVM is not clever enough to do it itself) and not as part of the symbolic graph, as it is now. For instance, it doesn't make sense that it's as hard to make an optimized loop as shown in ##174 (comment)
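To make the "rolling outputs and tracking the pointer" overhead concrete, here is a minimal plain-Python sketch. It is not pytensor code; the function name `fib_truncated` and the choice of a two-tap Fibonacci recurrence are assumptions made purely for illustration.

```python
def fib_truncated(n_steps, init=(0, 1)):
    """Run x[t] = x[t-2] + x[t-1], keeping only the last two states.

    Hypothetical illustration: the buffer has one slot per tap, and
    `ptr` tracks which slot currently holds the oldest state.
    """
    buf = list(init)  # two slots, one per tap
    ptr = 0           # index of the oldest entry in the buffer
    for _ in range(n_steps):
        new = buf[ptr] + buf[(ptr + 1) % 2]
        buf[ptr] = new        # overwrite the oldest entry in place
        ptr = (ptr + 1) % 2   # the other slot is now the oldest
    # "Roll" the buffer so the result is in chronological order.
    return buf[ptr:] + buf[:ptr]


last_two = fib_truncated(5)
```

The pointer arithmetic and the final roll are the bookkeeping in question; the sketch only shows what that bookkeeping looks like, not where in the compilation pipeline it should live.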
-
Thanks @aseyboldt. I'll have to read this in more detail tomorrow, but from a first quick glance I can say that you don't need to add another ifelse Op. The one in pytensor already evaluates each of the branches lazily (at least according to its docs).
-
BTW, the proposal is already 10x more appealing than the current one. I suggest we start working on it sooner rather than later. @aseyboldt, do you want to do it yourself, or are you happy to let someone else tackle it?
-
The other thing we should consider is how to transpile to JAX. The IfElse maps well to The Loop can be converted to the JAX while loop, but that is not reverse-mode differentiable by default. When we have an actual for loop and not a while loop, should we write it as a JAX scan, so that autodiff works? Also, why not fall back to a scan for autodiff in pytensor instead of raising, as in the L_op example at the top? The biggest issue would be to avoid having a redundant loop and scan with the same inner function (minus the set_subtensor). That reminds me of another option, which is for the loop primitive to have two types of outputs: the last state and the intermediate states. During compilation we could get rid of the intermediate states if they aren't used anywhere.
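As a small illustration of the scan route, here is a sketch assuming a fixed number of steps. `jax.lax.scan` already returns both kinds of output mentioned above (the last carry and the stacked per-step outputs) and supports reverse-mode autodiff. The function `run` and the toy update `carry * a` are hypothetical and chosen only to keep the numbers easy to check.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run(a, x0=1.0, n_steps=3):
    """Repeatedly multiply the state by `a` for a fixed number of steps."""

    def step(carry, _):
        new = carry * a
        # First element: next carry. Second element: per-step output.
        return new, new

    # scan returns (last state, stacked intermediate states).
    last, trace = lax.scan(step, jnp.float32(x0), xs=None, length=n_steps)
    return last, trace


last, trace = run(2.0)

# Reverse-mode gradient of the last state with respect to `a`.
# The same loop written with lax.while_loop and a dynamic stopping
# condition would not support jax.grad.
grad_last = jax.grad(lambda a: run(a)[0])(2.0)
```

Here only `last` is used in the differentiated function, which is the situation where one would hope the unused intermediate states get dropped during compilation. Whether that actually happens is not something this sketch verifies.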
-
The scan code in pytensor is pretty involved, and over time grew into something that's pretty hard to understand and work with.
In the last design meeting we discussed whether it might be a good idea to replace it with something completely new, redesigned from the ground up.
This is a first attempt at what such a replacement might look like.
First, we add a new loop construct that's hopefully quite simple and can represent arbitrary loops, not just scan-like loops (the Loop class below). Only some loops easily allow for reverse-mode autodiff, and we represent those as loops for which we know a reverse loop. We then add scan as a function that builds a loop and its reverse from those building blocks. A (pseudocode-ish) implementation of this idea:
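As a rough illustration only, and not the code from the proposal itself, the two building blocks might be sketched in plain Python like this. Everything here is hypothetical: the fields of `Loop`, the signature of `scan`, and the hand-written derivative `dcarry` are assumptions made so that the example is self-contained and runnable without pytensor.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Loop:
    """Hypothetical stand-in: apply `step` to a state while `cond` holds."""
    step: Callable
    cond: Callable
    reverse: Optional["Loop"] = None  # a known reverse loop, if there is one

    def run(self, state):
        while self.cond(state):
            state = self.step(state)
        return state


def scan(f, dcarry, xs):
    """Build a forward Loop over `xs` together with its reverse Loop.

    `f(carry, x)` is the update; `dcarry(carry, x)` is its derivative
    with respect to `carry`, supplied by hand in this sketch.
    """
    n = len(xs)
    trace = []  # carry seen before each step, needed by the reverse loop

    def fwd_step(state):
        i, carry = state
        trace.append(carry)
        return i + 1, f(carry, xs[i])

    def bwd_step(state):
        # Walk the trace backwards and accumulate d(final)/d(init).
        i, grad = state
        return i - 1, grad * dcarry(trace[i], xs[i])

    reverse = Loop(step=bwd_step, cond=lambda s: s[0] >= 0)
    return Loop(step=fwd_step, cond=lambda s: s[0] < n, reverse=reverse)


xs = [2.0, 3.0, 4.0]
loop = scan(lambda c, x: c * x, lambda c, x: x, xs)
_, final = loop.run((0, 1.5))                    # forward pass
_, grad = loop.reverse.run((len(xs) - 1, 1.0))   # d(final)/d(init)
```

A plain `Loop` built without a `reverse` would be the "arbitrary loop" case, where reverse-mode autodiff is simply unavailable.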
cc @ricardoV94 @Armavica @lucianopaz