Prototype an explicit embedding API #1677
Conversation
One particular use case that has come up for me with some regularity is scanning over statically sized ranges. We currently do this with a workaround; maybe an explicit helper function would be cleaner.
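A hedged sketch of the kind of workaround this likely refers to (my reconstruction, not the original snippet): scan over an explicit index array so the trip count is statically known.

```python
import jax.numpy as np  # the thread's pseudocode uses `np` for jax.numpy
from jax import lax

def powers(base, length):
  # Iterate a statically known number of steps by scanning over indices.
  def step(carry, i):
    new = carry * base
    return new, new
  _, ys = lax.scan(step, 1.0, np.arange(length))
  return ys
```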
@shoyer Interesting. What if we just exposed a `length` argument on `scan`?
That would do it, though it's not entirely clear to me what the signature of the scanned function should be in that case. I guess the two choices are either:

```python
def scan(f, init, length):
  carry = init
  ys = []
  for i in range(length):
    carry, y = f(carry, i)
    ys.append(y)
  return carry, np.stack(ys)
```

or

```python
def scan(f, init, length):
  carry = init
  ys = []
  for _ in range(length):
    carry, y = f(carry)
    ys.append(y)
  return carry, np.stack(ys)
```

The latter seems perhaps a little cleaner (what if you don't need the iteration counter?), but now the type signature of the scanned function looks different, i.e., `f` takes only the carry.
I was thinking of keeping the signature the same, but exposing a `length` argument. Concretely:

```python
def scan(f, init, xs, length=None):
  length = length if length is not None else get_axis_length(xs)
  carry = init
  ys = []
  for i in range(length):
    x = get_slice(xs, i)
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```

where `get_axis_length` and `get_slice` are placeholder helpers (sketched below). (EDIT: I didn't write the logic for raising an error when `length` disagrees with the length of `xs`.)
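Hypothetical fill-ins for those helpers, assuming `xs` is a single array (the real versions would presumably handle pytrees of arrays):

```python
def get_axis_length(xs):
  # Length of the leading axis, i.e. the axis being scanned over.
  return xs.shape[0]

def get_slice(xs, i):
  # The i-th element along the leading axis.
  return xs[i]
```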
I was trying to figure out what this would look like, and yes, that is obviously how things should work. In that case, let me suggest the following update to the pseudocode / pure-Python version of `scan`:

```python
def scan(f, init, xs=None, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```
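For reference, a sketch of how this reads at a call site (the example itself is mine; the `xs=None` plus `length` mode shown here matches the API that `lax.scan` eventually gained):

```python
from jax import lax

def powers_of_two(n):
  def step(carry, _):
    new = carry * 2.0
    return new, new
  # With xs=None and an explicit length, the body runs n times with x=None.
  _, ys = lax.scan(step, 1.0, None, length=n)
  return ys
```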
Before this change, the jaxpr

```
{ lambda ;  ; a.
  let b = add 2. 3.
      c = mul b a
  in [c] }
```

was not transposable, even though it is linear. The issue was that our transposition machinery assumed, in effect, that there were no constant subexpressions, and that constants appeared only as literals or as constvars. So this jaxpr would work:

```
{ lambda ;  ; a.
  let b = mul 5. a
  in [b] }
```

However, as in #1715, using `tie_in` to control the parts of the computation that get staged out to jaxprs and hence XLA (e.g. to avoid materializing large constants) could create jaxprs with constant subexpressions (indeed, that's the point!). This commit resolves the issue by adding a constant-evaluation pass to `backward_pass` in ad.py. In the grad-of-jit case, this evaluation is staged out to XLA.

We thought about having `tie_in` not appear in jaxprs (as in #1647), but that would remove our ability to stage computations out to the backward pass (e.g. to avoid materializing a large constant in the backward pass). The solution here is a stop-gap that we hope to see cleaned up a bit by #1677, which aims to provide an explicit embedding API to replace the use of `tie_in`.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
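As an illustration (my sketch, not from the commit), here is roughly how `tie_in` could give rise to a jaxpr with a constant subexpression like the one above (`lax.tie_in` existed at the time but has since been removed from JAX):

```python
from jax import grad, lax

def f(a):
  # tie_in fakes a data dependence on `a`, so the constant 2. becomes a
  # traced value and the following add is staged into the jaxpr rather
  # than evaluated eagerly, producing a constant subexpression.
  two = lax.tie_in(a, 2.)
  b = two + 3.
  return b * a  # linear in a, so the jaxpr should be transposable

print(grad(f)(1.0))  # 5.0
```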
We've had some better ideas here! Closing this as likely not the right direction to take.
Consider this program:
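The snippet itself isn't reproduced here, so the following is a hypothetical reconstruction based on the operations described below (an iota, a reshape, a less-than, and a `lax.select`):

```python
import jax.numpy as np
from jax import jit, lax

@jit
def foo(x):
  # Build a large boolean mask from shape information alone; note there is
  # no data dependence on the *value* of x.
  idx = lax.iota(np.int32, x.size).reshape(x.shape)
  mask = idx < x.size // 2
  # Use the mask to select between x and zeros.
  return lax.select(mask, x, np.zeros_like(x))
```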
What gets staged out to XLA? You might be surprised! The jaxpr seen by the tracing machinery is essentially this:
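A hedged sketch of roughly what that jaxpr might look like, with the materialized `mask` bound as the constvar `b` (and the zeros as `d`):

```
{ lambda b d ;  ; a.
  let c = select b a d
  in [c] }
```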
The creation of `mask` doesn't get staged out at all: instead, all those operations (iota, reshape, less-than) are executed eagerly (op-by-op), building the large constant in memory before it is built into the jaxpr as a constant (`b` in the jaxpr). That's a shame, because if XLA got to see the whole computation corresponding to `foo`'s body, it might not need to materialize `mask` at all, fusing its construction into the `lax.select` operation. Instead, the current behavior and its resulting memory inefficiency have surprised users and prevented them from expressing the XLA computations they want.

(More laziness, like #1668, can help, but has limits unless we make everything lazy, which opens up complicated tradeoffs. And even if we end up with full laziness someday, it should be a magical convenience layer on top of a more explicit API underneath.)
The crux of the issue is data dependence: `mask` has no data dependence on the argument `x` (depending on shape information doesn't count as a data dependence). All of JAX's embedding (i.e. tracing) machinery works by data dependence, and while other mechanisms are possible (not detailed here), relying on data dependence avoids a lot of complexity and other nasty surprises. In particular, it makes compositionality much easier. An unfortunate result is that here we're not able to stage out what we want, at least not without a new API. (The `lax.tie_in` primitive is a way to force a data dependence, and it's used specifically for this issue, but we're not happy with it because it requires a weird mental model and makes programs awkward.)

One way to think about the issue is in terms of embedding: we've embedded the jaxpr language (essentially XLA HLO) in Python in a way that looks like regular Python+NumPy, but as this case shows, it's sometimes unclear what jaxpr is being expressed, and sometimes hard or impossible to express exactly the jaxpr you want. We've optimized for convenience but sacrificed some expressiveness and clarity.
One solution is to provide an explicit variant of the embedding with which users can express any jaxpr, at the cost of some convenience. This PR is a prototype in that direction.
The basic proposal is to introduce a `jit` variant called `pure_jit` that enables more explicit embeddings.

So far, we've implemented `pure_jaxpr`, which is the explicit analogue of `make_jaxpr`. That is, `pure_jaxpr` is to `make_jaxpr` roughly as `pure_jit` is to `jit`.

Here are some examples:
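The prototype's actual examples aren't reproduced here, so the following is a hypothetical sketch; `pure_jaxpr`'s calling convention and the `pure.lit`/`pure.app` spellings follow the description below:

```python
import jax.numpy as np
from jax import lax

def f(pure, x):
  # pure.lit stages the constant into the jaxpr rather than materializing
  # it eagerly during tracing.
  ones = pure.lit(np.ones(3))
  # pure.app stages the primitive application out explicitly.
  return pure.app(lax.add, x, ones)

print(pure_jaxpr(f)(np.zeros(3)))  # a jaxpr with the add staged out
```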
The basic idea is that `pure_jit` (and here `pure_jaxpr`) provides an additional argument to the function being called. That additional argument can be used to ensure that values, including constants, get staged out (`pure.lit`), and to ensure that primitive applications are staged out (`pure.app`).

To be super explicit, a user could write every primitive application with `pure.app`. But usually these things are only needed in special places, particularly when creating constants.

Here's what the original program `foo` could look like:
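Again a hypothetical sketch, with the exact spellings assumed:

```python
import jax.numpy as np
from jax import lax

def foo(pure, x):
  # Every step of the mask construction is staged out with pure.app, so
  # the large constant is never materialized during tracing.
  idx = pure.app(lax.iota, np.int32, x.size)
  idx = pure.app(lax.reshape, idx, x.shape)
  half = pure.app(lax.full, x.shape, x.size // 2, np.int32)
  mask = pure.app(lax.lt, idx, half)
  zeros = pure.app(lax.full, x.shape, 0., x.dtype)
  return pure.app(lax.select, mask, x, zeros)
```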
We could imagine providing some convenience methods, such as `pure.iota` and `pure.zeros`, etc., but the fundamental components are `pure.lit` and `pure.app`. (In fact, if we have an identity primitive, we only need `pure.app`. See also `kyapply` from Autograd's early history.)

TODO:

- `pure_jit`