omnistaging #3370
Merged

Conversation
mattjj force-pushed the omnistaging3 branch 2 times, most recently from df8b1ce to b707fd2 on June 10, 2020 at 20:01
mattjj force-pushed the omnistaging3 branch 3 times, most recently from 86543bd to a918e9b on June 13, 2020 at 20:26
mattjj added a commit that referenced this pull request on Jun 16, 2020
The main win here is reducing the number of arguments for the function that parameterizes _remat_partial_eval (so it can be used both with remat and invertible ad features). I also included a fix to _remat_partial_eval that is needed in #3370, though I don't think it's needed on master. It was easier to include the fix now. Both these changes made rebasing #3370 easier!
mattjj force-pushed the omnistaging3 branch 2 times, most recently from 4333071 to b98f29c on June 16, 2020 at 19:24
mattjj force-pushed the omnistaging3 branch 3 times, most recently from 496eecf to 63170cd on June 18, 2020 at 00:24
mattjj force-pushed the omnistaging3 branch 2 times, most recently from e761d9a to 1f68f0f on June 24, 2020 at 04:30
copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request on Jun 24, 2020
Omnistaging (jax-ml/jax#3370) is going to introduce a new, dynamic scoping mechanism for JAX tracers. Currently we have been lucky with `scan` and `dynamic_unroll`, since the function we have been scanning has typically only been impure during `init`, and the only side effect has been creating params with an rng key captured from outside. While technically this is a violation of the contract for scan (`f` was not a pure function), it has been working fine (because `f` only has side effects that `scan` cannot observe). However, omnistaging sees all and will trigger an error here (thankfully this has been caught by our existing tests!).

The fix here is to explicitly thread Haiku state in and out of the `scan` and use this `hk.scan` to implement `dynamic_unroll`. I have tried to make this more efficient by only threading `state` (e.g. batch norm averages) and `rng` in and out of the `scan`. This is only safe to do because during `init` we unroll the first step of the `scan`.

PiperOrigin-RevId: 318067893
Change-Id: I9844d46cfc282e08544ca0eae819940627552d28
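A minimal sketch of that fix in plain JAX (not Haiku's actual `hk.scan` or `dynamic_unroll`; the state and update rule here are made up): thread the state and rng through `scan`'s carry instead of mutating them as a side effect inside the scanned function.

```python
import jax.numpy as jnp
from jax import lax, random

def step(carry, x):
    state, rng = carry
    rng, subkey = random.split(rng)
    new_state = state + random.normal(subkey, x.shape) * x  # made-up update rule
    return (new_state, rng), new_state                      # carry out, per-step output

state0 = jnp.zeros(3)
rng0 = random.PRNGKey(0)
xs = jnp.ones((5, 3))
(final_state, final_rng), ys = lax.scan(step, (state0, rng0), xs)
```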
mattjj force-pushed the omnistaging3 branch 3 times, most recently from 5e8507a to df97557 on June 27, 2020 at 18:02
mattjj force-pushed the omnistaging3 branch 2 times, most recently from 88d8055 to 2409c6e on July 2, 2020 at 02:20
hawkinsp added a commit to hawkinsp/jax that referenced this pull request on Mar 4, 2021
This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.) After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity). This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request on Mar 4, 2021
Updated version of jax-ml#4536. This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.) After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity). This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request on Mar 4, 2021
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request on Mar 17, 2021
Motivation
JAX transformations like `jit` and `pmap` stage out computations to XLA. That is, we apply them to functions comprising multiple `jax.numpy` operations so that, rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Here's an example:
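A minimal sketch of the kind of function being discussed (the exact `select_tril` from the PR may differ in details; the mask construction and dtypes here are assumptions):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

@jit
def select_tril(x):
    # mask depends only on x's static shape, not on x's values
    mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
    return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12, dtype=np.float32).reshape((3, 4))
select_tril(x)
```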
The `select` operation is staged out, but the operations for constructing the constant `mask` are not. Rather than being staged out, the operations that construct `mask` are executed op-by-op at Python tracing time, and XLA only sees a compile-time constant `constant.1` representing the value of `mask`. That's unfortunate, because if we had staged out the operations for constructing `mask`, XLA could have fused them into the `select` and avoided materializing the result at all. As a result we end up wasting memory with a large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The `broadcast` that corresponds to the construction of the zeros array for `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple expressions from #1668.)

The reason the creation of `mask` is not staged out is that `jit` operates based on data dependence. That is, `jit` stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and `pmap` behave similarly. In the case of `select_tril`, the operations to construct the constant `mask` do not have a data dependence on the argument `x`, so they are not staged out; only the `lax.select` call has a data dependence.

The omnistaging change is about enabling `jit`, `pmap`, and control flow primitives to stage out more computation to XLA. As the name implies, it's about staging out as much as possible! More precisely, with omnistaging all `jax.numpy` calls in the dynamic context of a `jit`-transformed function are staged out to XLA.

After omnistaging, the construction of `mask` is staged out too, so the computation XLA sees for `select_tril` includes those operations and can fuse them into the `select` instead of materializing the constant.
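Assuming the `select_tril` and `x` from the sketch above, one way to inspect what actually gets staged out is to print the jaxpr (the exact printed form depends on the JAX version):

```python
import jax

print(jax.make_jaxpr(select_tril)(x))
# With omnistaging, the jaxpr records the iota/broadcast/gt operations that build
# `mask`; without it, `mask` shows up only as a baked-in constant.
```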
In addition to improving JAX's memory performance, omnistaging enables a host of other improvements and simplifications throughout JAX. For example, it allows `lax.cond` to accept thunks, so that `lax.cond(x > 0, lambda: f(y), lambda: g(z))` will only evaluate one of `f(y)` and `g(z)`. Another example is that it lets us remove the lazy sublanguage of #1668.
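For example, with toy definitions of `f`, `g`, `y`, and `z` (assumptions, not from the PR), the thunk form looks like this; both branches are traced, but only the selected one executes at run time, and neither `f(y)` nor `g(z)` is computed eagerly while building the call:

```python
import jax.numpy as jnp
from jax import lax

y = jnp.ones(3)
z = jnp.ones(3)
x = 2.0

def f(y): return jnp.sin(y)  # toy branch computations
def g(z): return jnp.cos(z)

out = lax.cond(x > 0, lambda: f(y), lambda: g(z))
```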
What's the catch? Common issues and fixes
TODO
Enable with a flag!
This PR sets up parallel implementations of the old core and the new omnistaging one. The old one is still used by default; omnistaging is enabled by setting the `JAX_OMNISTAGING` environment variable, by setting the boolean flag `jax_omnistaging`, or by using `from jax.config import config; config.enable_omnistaging()`.
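Concretely (a sketch based on the description above; the environment-variable value and the exact flag plumbing are assumptions):

```python
# Option 1: environment variable, set before importing jax (value assumed):
#   export JAX_OMNISTAGING=1
# Option 2: explicit opt-in from Python:
from jax.config import config
config.enable_omnistaging()
# Option 3: the boolean flag, e.g. via the config object:
#   config.update("jax_omnistaging", True)
```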
Main implementation idea
The semantics of transformations in JAX are based on a stack of interpreters. When you call a transformed function, we think of pushing its transformations onto a global stack. Each stack frame represents a layer of interpretation. When we bind a primitive, we interpret it using the transformation on the top of the stack (calling the corresponding `Trace`'s `process_primitive`). If that transformation rule itself binds primitives, those binds are interpreted using the second stack frame. And so on down the stack. If the transformation rule for the interpreter on the bottom of the stack binds primitives, those binds get interpreted with an implicit evaluation interpreter. That implicit evaluation interpreter is special because we know that it will bind no primitives and that there are no more transformations to come. It's the exit.

Those are the semantics, but there's a kind of optimization: transformations in the stack are only applied if the input arguments to a primitive bind are boxed in a `Tracer` of the corresponding transformation. (If the inputs are boxed with different Tracers, the one corresponding to the highest level in the stack wins; that's the purpose of keeping the stack around, and having `level`s!) That lets us skip irrelevant transformations for primitive binds whose inputs haven't been "infected" with the transformation. Because we've always treated `jit` the same way, it also results in `jit`'s data-dependence behavior, where the only operations that are staged out are those boxed in the appropriate `Tracer`s because they have a data dependence on an argument to the `jit`-transformed function. As we've seen, this data-dependent staging behavior for `jit` can be undesirable.
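A toy sketch of that dispatch rule (hypothetical names, not JAX's actual internals): a bind is handled by the highest-`level` trace among its boxed arguments, and falls through to plain evaluation when no argument is a `Tracer`.

```python
class Trace:
    def __init__(self, level):
        self.level = level
    def process_primitive(self, prim, args):
        raise NotImplementedError  # each transformation supplies its own rule

class Tracer:
    def __init__(self, trace, value):
        self.trace, self.value = trace, value

def bind(prim, *args, eval_rule):
    traces = [a.trace for a in args if isinstance(a, Tracer)]
    if not traces:
        return eval_rule(*args)               # implicit evaluation: the "exit"
    top = max(traces, key=lambda t: t.level)  # highest level in the stack wins
    return top.process_primitive(prim, args)
```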
The solution idea here is to make `jit` (and `pmap`) act differently, by making the implicit bottom of the stack less implicit. That is, we instantiate the special exit-at-the-bottom-of-the-stack interpreter frame explicitly. We make the eval interpreter an explicit interpreter, and when the trace stack is initialized we put an eval interpreter at the base. The point of making it explicit is that now we can swap it out: when we trace a `jit`ted function, we swap out the base of the interpretation stack. That way, when the interpretation of a primitive bind bottoms out while we're tracing a `jit`ted function, instead of evaluating the primitive application we'll stage it out to a jaxpr. Because we always hit the bottom of the stack, regardless of whether arguments to the primitive bind are boxed in a particular Tracer, we end up staging out all primitive binds. Hence "omnistaging"!
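Continuing the toy sketch above (still hypothetical names, and `prim.impl` is an assumed attribute): the base of the stack becomes an explicit trace that always participates in dispatch, so swapping an eval trace for a staging trace at the base stages out every bind.

```python
class EvalTrace(Trace):
    def process_primitive(self, prim, args):
        return prim.impl(*args)            # evaluate the primitive right away

class StagingTrace(Trace):
    def __init__(self, level):
        super().__init__(level)
        self.eqns = []                     # equations for the jaxpr being built
    def process_primitive(self, prim, args):
        self.eqns.append((prim, args))     # record the bind instead of evaluating it
        return Tracer(self, value=None)    # abstract output, no concrete value

def bind(prim, *args, base_trace):         # redefines bind from the sketch above
    # The base trace always participates, so binds bottom out even when no argument
    # is boxed in a Tracer; with a StagingTrace base, every bind gets staged out.
    traces = [a.trace for a in args if isinstance(a, Tracer)] + [base_trace]
    top = max(traces, key=lambda t: t.level)
    return top.process_primitive(prim, args)
```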
(There's one more twist, but it's not really fundamental. The `dynamic` attribute on the trace stack is one additional mechanism to handle control flow primitives. For those, we want to insert a "special exit-the-system staging frame" at an arbitrary point in the trace stack, not just at the bottom. So they trace functions to jaxprs by pushing a `dynamic` trace to the top of the stack rather than swapping out the base. That in effect lets us point at a function and say, "regardless of whatever transformations are already going on in the system, and regardless of what transformations this function will itself push onto the stack when I run it, give me a jaxpr for this function on these avals!")
For comparison, the current implementation sets up `jit` (and `pmap`) traces in a special downward-growing stack, which in some ways is also meant to always sit under the regular transformation stack. But it's a whole stack, rather than just a single base frame, because it has to sort out data-dependence issues (e.g. in the case of multi-input primitives) and decide, for any given primitive bind, which `jit` traces it should be staged into. Since we're not relying on data dependence anymore, we don't need a whole stack!