Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

omnistaging #3370

Merged
merged 1 commit into from
Jul 30, 2020
Merged

omnistaging #3370

merged 1 commit into from
Jul 30, 2020

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Jun 8, 2020

Motivation

JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple jax.numpy operations so that rather being executed one at a time from Python the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Here’s an example:

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)
ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from #1668.)

The reason the creation of mask is not staged out is that jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an augment. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.

The omnistaging change is about enabling jit, pmap, and control flow primitives to stage out more computation to XLA. As the name implies, it’s about staging out as much as possible! More precisely, with omnistaging all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA.

After omnistaging, the computation XLA sees for select_tril is

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

In addition to improving JAX’s memory performance, omnistaging enables a host of other improvements and simplifications throughout JAX. For example, it allows lax.cond to accept thunks, so that lax.cond(x > 0, lambda: f(y), lambda: g(z)) will only evaluate one of f(y) and g(z). Another example is that it lets us remove the lazy sublanguage of #1668.

What's the catch? Common issues and fixes

TODO

Enable with a flag!

This PR sets up parallel implementations of the old core and the new omnistaging one. The old one is still used by default; omnistaging is enabled by setting the JAX_OMNISTAGING environment variable, by setting the boolean flag jax_omnistaging, or by using from jax.config import config; config.enable_omnistaging().

Main implementation idea

The semantics of transformations in JAX are based on a stack of interpreters. When you call a transformed function, we think of pushing its transformations onto a global stack. Each stack frame represents a layer of interpretation. When we bind a primitive, we interpret it using the transformation on the top of the stack (calling the corresponding Trace's process_primitive). If that transformation rule itself binds primitives, those binds are interpreted using the second stack frame. And so on down the stack. If the transformation rule for the interpreter on the bottom of the stack binds primitives, those binds get interpreted with an implicit evaluation interpreter. That implicit evaluation interpreter is special because we know that it will bind no primitives and that there are no more transformations that can come. It's the exit.

Those are the semantics, but there's a kind of an optimization: transformations in the stack are only applied if the input arguments to a primitive bind are boxed in a Tracer of the corresponding transformation. (If the inputs are boxed with different Tracers, the one corresponding to the highest level in the stack wins; that's the purpose of keeping the stack around, and having levels!) That lets us skip irrelevant transformations for primitive binds whose inputs haven't been "infected" with the transformation. Because we've always treated jit the same way, it also results in jit's data-dependence behavior, where the only operations that are staged out are those boxed in the appropriate Tracers because they have a data dependence on an argument to the jit-transformed function. As we've seen, this data-dependent-staging behavior for jit can be undesirable.

The solution idea here is to make jit (and pmap) act differently, by making the implicit bottom of the stack less implicit. That is, we instantiate the special exit-at-the-bottom-of-the-stack interpreter frame explicitly. We make the eval interpreter an explicit interpreter, and when the trace stack is initialized we put an eval interpreter at the base. The point of making it explicit is that now we can swap it out: when we trace a jitted function, we swap out the base of the interpretation stack. That way, when the interpretation of a primitive bind bottoms out while we're tracing a jitted function, instead of evaluating the primitive application we'll stage it out to a jaxpr. Because we always hit the bottom of the stack, regardless of whether arguments to the primitive bind are boxed in a particular Tracer, we end up staging out all primitive binds. Hence "omnistaging"!

(There's one more twist, but it's not really fundamental. The dynamic attribute on the trace stack is one additional mechanism to handle control flow primitives. For those, we want to insert a "special exit-the-system staging frame" at an arbitrary point in the trace stack, not just at the bottom. So they trace functions to jaxprs by pushing a dynamic trace to the top of the stack rather than swapping out the base. That in effect lets us point at a function and say "regardless of whatever transformations are already going on in the system, and regardless of what transformations this function will itself push onto the stack when I run it, give me a jaxpr for this function on these avals!")

For comparison, the current implementation sets up jit (and pmap) traces in a special downward-growing stack, which in some ways also is meant always to sit under the regular transformation stack. But it's a whole stack, rather than just a single base frame, because it has to sort out data dependence issues (e.g. in the case of multi-input primitives) and decide for any given primitive bind which jit traces it should be staged into. Since we're not relying on data dependence anymore, we don't need a whole stack!

@mattjj mattjj force-pushed the omnistaging3 branch 2 times, most recently from df8b1ce to b707fd2 Compare June 10, 2020 20:01
@mattjj mattjj force-pushed the omnistaging3 branch 3 times, most recently from 86543bd to a918e9b Compare June 13, 2020 20:26
mattjj added a commit that referenced this pull request Jun 16, 2020
The main win here is reducing the number of arguments for the function
that parameterizes _remat_partial_eval (so it can be used both with
remat and invertible ad features).

I also included a fix to _remat_partial_eval that is needed in #3370,
though I don't think it's needed on master. It was easier to include the
fix now.

Both these changes made rebasing #3370 easier!
@mattjj mattjj force-pushed the omnistaging3 branch 2 times, most recently from 4333071 to b98f29c Compare June 16, 2020 19:24
@mattjj mattjj force-pushed the omnistaging3 branch 3 times, most recently from 496eecf to 63170cd Compare June 18, 2020 00:24
@mattjj mattjj force-pushed the omnistaging3 branch 2 times, most recently from e761d9a to 1f68f0f Compare June 24, 2020 04:30
copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request Jun 24, 2020
Omnistaging (jax-ml/jax#3370) is going to introduce a new, dynamic scoping
mechanism for JAX tracers. Currently we have been lucky with `scan` and
`dynamic_unroll`, since the function we have been scanning has typically only
been impure during `init` and the only side effect has been creating params with
an rng key captured from outside. While technically this is a violation of the
contract for scan (`f` was not a pure function), it has been working fine
(because `f` only has side effects that `scan` cannot observe). However..
omnistaging sees all and will trigger an error here (thankfully this has been
caught by our existing tests!!).

The fix here is to explicitly thread Haiku state in and out of the `scan` and
use this `hk.scan` to implement `dynamic_unroll`. I have tried to make this
more efficient by only threading `state` (e.g. batch norm averages) and `rng` in
and out of the `scan`. This is only safe to do since during `init` we unroll the
first step of the `scan`.

PiperOrigin-RevId: 318067893
Change-Id: I9844d46cfc282e08544ca0eae819940627552d28
@mattjj mattjj force-pushed the omnistaging3 branch 3 times, most recently from 5e8507a to df97557 Compare June 27, 2020 18:02
@mattjj mattjj force-pushed the omnistaging3 branch 2 times, most recently from 88d8055 to 2409c6e Compare July 2, 2020 02:20
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Mar 17, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants