
Prototype an explicit embedding API #1677

Closed · wants to merge 1 commit

Conversation

mattjj (Collaborator) commented Nov 13, 2019

Consider this program:

import jax.numpy as np
from jax import jit, lax

@jit
def foo(x):
  M, N = x.shape
  mask = np.arange(M)[:, None] < np.arange(N)[None, :]
  return lax.select(mask, x, 0)

What gets staged out to XLA? You might be surprised! The jaxpr seen by the tracing machinery is essentially this:

{ lambda b c ;  ; a.
  let d = select b a c
  in [d] }

The creation of mask doesn't get staged out at all: instead, all those operations (iota, reshape, less-than) are executed eagerly (op-by-op), materializing the large constant in memory before embedding it in the jaxpr as a constant (b in the jaxpr). That's a shame, because if XLA got to see the whole computation corresponding to foo's body, it might not need to materialize mask at all, fusing its construction into the lax.select operation. Instead, the current behavior and its resulting memory inefficiency have surprised users and prevented them from expressing the XLA computations they want.

(More laziness, like #1668, can help, but has limits unless we make everything lazy, which opens up complicated tradeoffs. And even if we end up with full laziness someday, it should be a magical convenience layer on top of a more explicit API underneath.)

The crux of the issue is data dependence: mask has no data dependence on the argument x (depending on shape information doesn't count as a data dependence). All of JAX's embedding (i.e. tracing) machinery works by data dependence, and while other mechanisms are possible (not detailed here), relying on data dependence avoids a lot of complexity and other nasty surprises. In particular, it makes compositionality much easier. An unfortunate result is that here we're not able to stage out what we want, at least not without a new API. (The lax.tie_in primitive is a way to force a data dependence, and it's used specifically to work around this issue. But we're not happy with it because it requires a weird mental model and makes programs awkward.)
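
For reference, here's roughly how foo can be written with lax.tie_in today; a sketch of the workaround this PR aims to replace (np.zeros_like is used for the on-false operand just to keep the sketch self-contained):

import jax.numpy as np
from jax import jit, lax

@jit
def foo(x):
  M, N = x.shape
  # tie_in(x, c) returns c but with a fake data dependence on x, so the
  # mask construction below is staged into the jaxpr instead of being
  # evaluated eagerly at trace time
  rows = lax.tie_in(x, np.arange(M))
  cols = lax.tie_in(x, np.arange(N))
  mask = rows[:, None] < cols[None, :]
  return lax.select(mask, x, np.zeros_like(x))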

One way to think about the issue is in terms of embedding: we've embedded the jaxpr language (essentially XLA HLO) in Python in a way that looks like regular Python+NumPy, but as this case shows, it's sometimes unclear what jaxpr is being expressed, and sometimes hard or impossible to express exactly the jaxpr you want. We've optimized for convenience but sacrificed some expressiveness and clarity.

One solution is to provide an explicit variant of the embedding with which users can express any jaxpr, at the cost of some convenience. This PR is a prototype in that direction.

The basic proposal is to introduce a jit variant called pure_jit that enables more explicit embeddings.

So far, we've implemented pure_jaxpr, the explicit analogue of make_jaxpr. That is, pure_jaxpr is to make_jaxpr roughly as pure_jit is to jit.

Here are some examples:

import jax.numpy as np
from jax import lax
from jax import pure_jaxpr

def f(pure):
  x = pure.lit(1)
  y = np.broadcast_to(x, (1000,))
  return y
print(pure_jaxpr(f))  # pure_jaxpr is to make_jaxpr as pure_jit is to jit
{ lambda  ;  ; .
  let a = broadcast_in_dim[ shape=(1000,)
                            broadcast_dimensions=() ] 1
  in [a] }

def g(pure, y):
  x = pure.app(lax.iota_p, dtype=np.int32, size=10)
  return x + y
print(pure_jaxpr(g, 3))
{ lambda  ;  ; a.
  let b = iota[ dtype=<class 'numpy.int32'>
                size=10 ]
      c = add b a
  in [c] }

The basic idea is that pure_jit (and here pure_jaxpr) provide an additional argument to the function being called. That additional argument can be used to ensure values, including constants, get staged out (pure.lit), and to ensure that primitive applications are staged out (pure.app).

To be super explicit, a user could write every primitive application with pure.app. But usually these calls are only needed in special places, particularly when creating constants.

Here's what the original program foo could look like:

import jax.numpy as np
from jax import lax, pure_jit

@pure_jit
def foo(pure, x):
  a = pure.app(lax.iota_p, size=x.shape[0], dtype=np.int32)
  b = pure.app(lax.iota_p, size=x.shape[1], dtype=np.int32)
  mask = a[:, None] < b[None, :]
  return lax.select(mask, x, 0)

We could imagine providing some convenience methods, like pure.iota and pure.zeros, but the fundamental components are pure.lit and pure.app. (In fact, if we have an identity primitive, we only need pure.app. See also kyapply from Autograd's early history.)
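
For instance, such conveniences could be thin wrappers over the fundamental components (a hypothetical sketch; these wrappers aren't part of this PR, and the names are illustrative):

# Hypothetical convenience wrappers built on the prototype's pure.lit / pure.app.
def pure_iota(pure, size, dtype=np.int32):
  return pure.app(lax.iota_p, dtype=dtype, size=size)

def pure_zeros(pure, shape, dtype=np.float32):
  # stage a scalar zero as a literal, then broadcast it inside the jaxpr
  return np.broadcast_to(pure.lit(dtype(0)), shape)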

TODO:

  • iterate on the API
  • check the behavior under partial evaluation (i.e. grad): does grad-of-jit cause these things to be executed in an op-by-op way?
  • add pure_jit
  • if needed, generalize our notion of jaxpr linearity (cf this comment)

Co-authored-by: Dougal Maclaurin <dougalm@google.com>

shoyer (Collaborator) commented Nov 16, 2019

One particular use-case that has come up for me with some regularity is scanning over statically sized ranges. We currently do this with lax.scan over np.arange(...) (e.g., in #1706), but I imagine that XLA could potentially optimize this better if we used an explicit embedding.

Maybe an explicit helper function like lax.scan_range would be in order?
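
For instance, a hypothetical sketch of such a helper in terms of the pure.app API from this PR (lax.scan_range doesn't exist; the name and signature are illustrative):

# Hypothetical: scan over a statically sized index range without eagerly
# materializing np.arange(length) as a constant at trace time.
def scan_range(pure, f, init, length):
  idx = pure.app(lax.iota_p, dtype=np.int32, size=length)  # staged iota
  return lax.scan(f, init, idx)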

mattjj (Collaborator, Author) commented Nov 16, 2019

@shoyer interesting. What if we just exposed a length argument to scan (so that it didn't require any extensive inputs)?

shoyer (Collaborator) commented Nov 16, 2019

> What if we just exposed a length argument to scan (so that it didn't require any extensive inputs)?

That would do it, though it's not entirely clear to me what the signature of the scanned function should be in that case.

I guess the two choices are either:

def scan(f, init, length):
  carry = init
  ys = []
  for i in range(length):
    carry, y = f(carry, i)
    ys.append(y)
  return carry, np.stack(ys)

or

def scan(f, init, length):
  carry = init
  ys = []
  for _ in range(length):
    carry, y = f(carry)
    ys.append(y)
  return carry, np.stack(ys)

The latter seems perhaps a little cleaner (what if you don't need the iteration counter?), but now the type signature looks different, i.e., (c -> (c, b)) -> c -> (c, [b]).

mattjj (Collaborator, Author) commented Nov 16, 2019

I was thinking of keeping the signature the same, but exposing the length parameter of the primitive. Then there could be zero or more scanned (i.e. extensive) inputs just like there can already be zero or more scanned outputs; it would be an error if length is provided and it disagrees with the leading axis length of a scanned input. (Currently there can be zero or more scanned outputs, but there must be at least one scanned input, just because that's currently the only way to get the length.)

Concretely:

def scan(f, init, xs, length=None):
  length = length if length is not None else get_axis_length(xs)
  carry = init
  ys = []
  for i in range(length):
    x = get_slice(xs, i)
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

where get_axis_length and get_slice handle the case where xs is an empty container (like a None).

(EDIT: I didn't write the logic for raising an error when length is provided and disagrees with get_axis_length(xs), but hopefully you get what I mean!)
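
For concreteness, a sketch of those helpers and the omitted check might look like this (get_axis_length, get_slice, and check_length are illustrative names; pytree handling via jax.tree_util):

from jax.tree_util import tree_flatten, tree_map

def get_axis_length(xs):
  leaves, _ = tree_flatten(xs)  # empty containers (like None) flatten to []
  if not leaves:
    raise ValueError("scan needs a scanned input or an explicit length")
  return leaves[0].shape[0]

def get_slice(xs, i):
  # index each leaf; a no-op on empty containers
  return tree_map(lambda x: x[i], xs)

def check_length(length, xs):
  # the omitted error logic: an explicit length must agree with the inputs
  if length is not None and tree_flatten(xs)[0]:
    inferred = get_axis_length(xs)
    if length != inferred:
      raise ValueError(f"length {length} disagrees with scanned input "
                       f"leading axis length {inferred}")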

shoyer (Collaborator) commented Nov 16, 2019

> I was thinking of keeping the signature the same, but exposing the length parameter of the primitive. Then there could be zero or more scanned (i.e. extensive) inputs just like there can already be zero or more scanned outputs; it would be an error if length is provided and it disagrees with the leading axis length of a scanned input.

I was trying to figure out what get_slice on None would look like, but now that I recall how pytrees work, it's obvious: it would be None, since empty pytrees get flattened into an empty list.

So yes, that is obviously how things should work. In that case, let me suggest the following update to the pseudocode / pure-Python version of scan:

def scan(f, init, xs=None, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
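
Under that version, a scan with no scanned inputs would just pass a length, e.g. (a usage sketch, not part of this PR):

def body(carry, _):
  return carry + 1, carry  # (new carry, per-step output)

final, ys = scan(body, 0, length=5)
# final == 5, ys == np.stack([0, 1, 2, 3, 4])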

mattjj added a commit that referenced this pull request Nov 19, 2019
Before this change, the jaxpr

```
{ lambda  ;  ; a.
  let b = add 2. 3.
      c = mul b a
  in [c] }
```

was not transposable, even though it is linear. The issue was that our
transposition machinery assumed that in effect there were no constant
subexpressions, and that constants appeared only as literals or as
constvars. So this jaxpr would work:

```
{ lambda  ;  ; a.
  let b = mul 5. a
  in [b] }
```

However, as in #1715, using `tie_in` to control the parts of the
computation that get staged out to jaxprs and hence XLA (e.g. to avoid
materializing large constants) could create jaxprs with constant
subexpressions (indeed, that's the point!).

This commit resolves the issue by adding a constant evaluation pass to
`backward_pass` in ad.py. In the grad-of-jit case, this evaluation is
staged out to XLA.

We thought about having `tie_in` not appear in jaxprs (as in #1647), but
that would remove our ability to stage computations out to the backward
pass (e.g. to avoid materializing a large constant in the backward
pass).

The solution here is a stop-gap that we hope to be cleaned up a bit with #1677,
which aims to provide an explicit embedding API to replace the use of `tie_in`.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
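
For illustration, a hedged sketch of the kind of grad-of-jit program this commit is about (following the tie_in pattern from #1715; the exact jaxpr may differ):

import jax
import jax.numpy as np
from jax import jit, lax

@jit
def f(a):
  two = lax.tie_in(a, np.float32(2.))  # stage the constant via tie_in
  b = two + np.float32(3.)             # a constant subexpression: add 2. 3.
  return b * a                         # linear in a

# transposing the resulting jaxpr previously failed; with the constant
# evaluation pass in backward_pass, the gradient comes out as expected
print(jax.grad(f)(np.float32(1.)))  # 5.0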

mattjj (Collaborator, Author) commented Jan 7, 2020

We've had some better ideas here! Closing this as likely not the right direction to take.

@mattjj mattjj closed this Jan 7, 2020
@jakevdp jakevdp deleted the explicit-embedding branch October 6, 2021 19:35