Lazy sublanguage #1668
Conversation
Running the mnist_vae.py example, I noticed something unfortunate: we were compiling the update loop twice.

First call's signature:

Second and subsequent calls' signatures:

The difference is in the broadcasts, like this pair of signature entries:

Because of the lazy broadcasts, the first call's arguments have a different signature than later calls', so we recompile. Two possible solutions:

Discussing with @hawkinsp, we think Option 1 sounds like a better heuristic. (As an extension, if we want it, we could add a ….)

@hawkinsp articulated these principles:

If we want to be fancy, we could have a heuristic like: force arguments for "big" or slow-to-compile computations, and be lazy otherwise. We have all that information in …
This change is to avoid recompiles. See comment: #1668 (comment) Thanks @hawkinsp for help with this. Also, make force(x) update x's device_buffer reference.
Before this commit, evaluating x[:, None] * x[None, :] for a vector x in op-by-op (eager) mode would compile and execute 3 XLA computations and materialize a total of 3 result buffers. After this commit, it compiles and executes 1 XLA computation and materializes only one result buffer. Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full). Finally, this commit replaces the ad-hoc "lazy device constant" system. See #1668 for more.
Before this commit, this computation would avoid materializing the iota array at trace time:

```python
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(n)
```

But this one would materialize the iota array at trace time and stage it into the computation as a potentially large array constant:

```python
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(m)[:, None]
```

The difference is that previously operations like broadcasts, transposes, and reshapes that add singleton dimensions (as above) would force otherwise lazy values to be materialized, while after this commit broadcasts, transposes, and reshapes are all lazy operations that only update metadata on their input rather than compiling and executing XLA computations and producing new buffers. Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full). This commit replaces the ad-hoc "lazy device constant" system, which was used to get the simpler behavior in the first example above. Incidentally fixes #1431. See #1668 for more.
It would be really nice to pull much of the PR description into the documentation; otherwise that info may be hard to find.

You're right. The "Lazy Sublanguage" section appears as a comment in the source code, but I think you probably meant the jax.readthedocs.io documentation. I'll add it to my todos (documentation on this as well as on the device placement policy stuff we discussed yesterday).
This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.) After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity). This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
Updated version of jax-ml#4536. This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.) After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity). This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
TLDR
Before this commit, the first computation below would avoid materializing the iota (`np.arange`) array at trace time, while the second would materialize the iota array at trace time and stage it into the computation as a potentially large array constant.
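Both computations, reproduced from the commit message above (they assume `from jax import jit` and `import jax.numpy as np`); first the one that stays lazy:

```python
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(n)
```

and the one that previously materialized the iota at trace time:

```python
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(m)[:, None]
```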
The difference is that previously operations like broadcasts, transposes, and singleton-dimension-adding reshapes (as above) would force otherwise lazy values to be materialized, while after this commit broadcasts, transposes, and those reshapes are all lazy operations that only update metadata and reuse the same device buffer as their input rather than compiling and executing XLA computations and producing new buffers.
Also, `np.eye` and `np.tri` become lazy (in addition to `np.zeros`, `np.ones`, `np.full`). Finally, this PR replaces the ad-hoc "lazy device constant" system.
In an earlier version of this PR, I had also included the feature that lazy expressions would be fused into eager mode op-by-op compiled computations. After thinking it through with @hawkinsp we decided to split that functionality out into a follow-up PR, if we decide to include it at all.
TODO
- make `jit` force its arguments (see comment below)

Follow-up work for subsequent PRs

- `lax.tie_in` in as many places as possible
- `pmap` dispatch

Design idea
The basic idea is to introduce a kind of lazy sublanguage. Each DeviceArray carries an expression in this lazy sublanguage, and when we apply certain operations we'll produce a result that refers to the same underlying buffer as the input but has an updated expression (rather than compiling and executing an XLA computation to produce a new result buffer). Only when we apply functions to the array that don't exist in the lazy sublanguage will we compile and execute an XLA computation (taking into account any lazy expressions on the inputs).
We want a proper sub-language rather than, say, modeling the full jaxpr language, because that way we sidestep a tough tradeoff: if the lazy sublanguage can only express "cheap" operations, meaning operations we're happy to stage separately into multiple downstream consumers (and hence could be evaluated multiple times), then we don't need to worry about work sharing. If instead we had expensive computations, which we might not want to evaluate more than once, then we would need a more complex system that attempted to share work between evaluations of lazy expressions. (It gets complex because when we evaluate a lazy subexpression we'd need to decide which of its intermediates to materialize, and then update all equivalent lazy subexpressions in the system with the materialized values we computed.) Unfortunately "cheap" is hard to define precisely without understanding XLA better (on all backends), but the kinds of operations outlined above are ones that we expect can be fused into consumers, and so are "cheap" because the values of these expressions may never be materialized at all.
Because we're attaching expressions to DeviceArrays, this design operates "underneath" all the tracing logic: it sits under the impl rules. In other words, as the toy model below makes clear, it's something one can do in any numerical library, and all of JAX's tracing and transformation machinery sits on top unmodified.
Lazy sublanguage
Other than being able to express the broadcasts, reshapes, and transposes we want, some design criteria for the language are:
Here's the abstract syntax in terms of AST constructors, described below (with a Python sketch after it).
There are two components to a LazyExpr: an input and a reindexing specification. The input represents a base array to which the reindexing specification is applied.
An input can represent an array constructor (`Iota`, `Eye`, etc.) or it can be an `ArrayVar`, which encodes that the base array is some exogenous array value. (These LazyExprs are attached to DeviceArrays, so when the input part of the expression is `ArrayVar`, that basically means the associated device buffer is the input, while if the input is an array constructor then the associated `device_buffer` field of the DeviceArray should be set to the sentinel value `xla.device_constant`.)

The reindexing specification encodes the shape of the final result and a list of dimensions, which are integers or Nones. The integer entries take on values 0, 1, ..., N-1, where N is the rank of the input array, and encode where the axes of the input array are to be mapped in the final output. When an entry is None, that indicates that the corresponding axis of the result is a broadcast one.
The corresponding AST constructors in Python look like
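something along these lines (a sketch based on the description above; the constructor and field names are approximations, not necessarily the exact definitions in lazy.py):

```python
from collections import namedtuple

# A LazyExpr pairs an input with a reindexing specification: the shape of the
# final result plus, for each result axis, either the input axis mapped there
# or None for a broadcast axis.
LazyExpr = namedtuple("LazyExpr", ["input", "shape", "dims"])

# Inputs: an exogenous array value (the DeviceArray's own buffer) ...
ArrayVar = namedtuple("ArrayVar", [])
# ... or an array constructor such as iota or eye (field names are guesses).
Iota = namedtuple("Iota", ["dtype", "size"])
Eye = namedtuple("Eye", ["dtype", "shape", "offset"])
```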
Here are some examples of lazy expressions and the arrays they represent:
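Continuing the sketch above (the notation is illustrative, not the PR's exact one):

```python
import numpy as np

# An unmodified 3-vector backed by a device buffer:
vec = LazyExpr(input=ArrayVar(), shape=(3,), dims=(0,))

# The same vector as x[:, None] broadcast against a (3, 4) result: result axis 0
# comes from input axis 0, and result axis 1 is a broadcast axis (None):
col = LazyExpr(input=ArrayVar(), shape=(3, 4), dims=(0, None))

# The transpose of a (3, 4) buffer, with no new buffer materialized:
tr = LazyExpr(input=ArrayVar(), shape=(4, 3), dims=(1, 0))

# A lazy np.arange(3): the input is a constructor, so there is no real buffer
# and the DeviceArray's device_buffer would be the xla.device_constant sentinel:
rng = LazyExpr(input=Iota(np.dtype("float32"), 3), shape=(3,), dims=(0,))
```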
See xla.py for a numpy-based interpreter of this language.
Toy model
Here's a model showing the main idea (except for the changes to op-by-op and jit/pmap logic in xla.py).
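A minimal numpy-only sketch of the same idea (the class and function names here are illustrative, not the PR's actual toy model, and the constructor inputs like `Iota` are elided because the buffer itself plays the `ArrayVar` role): broadcasts and transposes only rewrite the attached reindexing expression, and only `force` does real array work.

```python
from collections import namedtuple
import numpy as np

# dims holds, per result axis, the buffer axis mapped there, or None for a
# broadcast axis (mirroring the reindexing specification described above).
ReindexExpr = namedtuple("ReindexExpr", ["shape", "dims"])

class LazyArray:
  def __init__(self, buf, expr=None):
    self.buf = buf  # the underlying buffer (a plain numpy array here)
    self.expr = expr or ReindexExpr(tuple(buf.shape), tuple(range(buf.ndim)))

  def transpose(self, perm):
    e = self.expr
    new = ReindexExpr(tuple(e.shape[p] for p in perm),
                      tuple(e.dims[p] for p in perm))
    return LazyArray(self.buf, new)  # same buffer, permuted metadata

  def broadcast_to(self, shape):
    # add leading broadcast axes; again this only touches metadata
    e = self.expr
    pad = len(shape) - len(e.shape)
    return LazyArray(self.buf, ReindexExpr(tuple(shape), (None,) * pad + e.dims))

def force(x):
  # Materialize the value with numpy: the only place real array work happens.
  e = x.expr
  perm = [d for d in e.dims if d is not None]  # buffer axes in result order
  arr = np.transpose(x.buf, perm)
  for out_axis, d in enumerate(e.dims):
    if d is None:
      arr = np.expand_dims(arr, out_axis)      # size-1 axis for each broadcast dim
  return np.broadcast_to(arr, e.shape)

x = LazyArray(np.arange(3.0))
y = x.broadcast_to((4, 3)).transpose((1, 0))   # shape (3, 4), still the same buffer
assert y.buf is x.buf                          # no new buffer was materialized
print(force(y))                                # values are computed only here
```

In the real system the non-lazy case dispatches an XLA computation rather than numpy, and forcing also has to evaluate or stage constructor inputs such as `Iota`.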
Micro-benchmarks
This PR adds some work to the op-by-op dispatch path, so we want to check that there isn't a significant performance regression. In the absence of proper performance regression tests, I just did some quick checks by hand.
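A sketch of the kind of quick hand check this refers to, assuming a plain `timeit` loop over an op-by-op `lax.add` (illustrative; not the exact commands used):

```python
import timeit
import jax.numpy as np
from jax import lax

x = np.ones(1)
lax.add(x, x)  # warm up so compilation isn't included in the timed region
print(timeit.timeit(lambda: lax.add(x, x), number=10000))
```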
On master (CPU):
On branch (CPU):
I did something similar with `jit(lax.add)` and there wasn't any movement there either.

Implementation notes
- The lazy sublanguage itself (together with a numpy-based interpreter, `eval_lexpr`, and an XLA-based interpreter, `stage_lexpr`) is added in lazy.py.
- DeviceArrays carry a `_lazy_expr` attribute.
- The `device_buffer` of a lazy constant is set to the sentinel value `device_constant`, rather than using `None` in some places.

Fixes #1909, and incidentally fixes #1431 because I happened to be updating that code.