Lazy sublanguage #1668

Merged
mattjj merged 1 commit into master from lazy-sublanguage on Jan 8, 2020

Conversation

mattjj (Collaborator) commented Nov 12, 2019

TLDR

Before this commit, this computation would avoid materializing the iota (np.arange) array at trace time:

from jax import jit
import jax.numpy as np

@jit
def f(x):
  m, n = x.shape
  return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it into the computation as a potentially large array constant:

@jit
def f(x):
  m, n = x.shape
  return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts, transposes, and singleton-dimension-adding reshapes (as above) would force otherwise-lazy values to be materialized. After this commit, broadcasts, transposes, and those reshapes are all lazy operations: they only update metadata and reuse the same device buffer as their input, rather than compiling and executing XLA computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

Finally, this PR replaces the ad-hoc "lazy device constant" system.

In an earlier version of this PR, I had also included the feature that lazy expressions would be fused into eager mode op-by-op compiled computations. After thinking it through with @hawkinsp we decided to split that functionality out into a follow-up PR, if we decide to include it at all.

TODO

  • add a way to make identities and triangular matrices
  • systematic tests
  • make jit force its arguments (see comment below)
  • microbenchmarks
  • write implementation notes below

Follow-up work for subsequent PRs

Design idea

The basic idea is to introduce a kind of lazy sublanguage. Each DeviceArray carries an expression in this lazy sublanguage, and when we apply certain operations we'll produce a result that refers to the same underlying buffer as the input but has an updated expression (rather than compiling and executing an XLA computation to produce a new result buffer). Only when we apply functions to the array that don't exist in the lazy sublanguage will we compile and execute an XLA computation (taking into account any lazy expressions on the inputs).

We want a proper sub-language rather than, say, modeling the full jaxpr language, because that way we sidestep a tough tradeoff: if the lazy sublanguage can only express "cheap" operations, meaning operations we're happy to stage separately into multiple downstream consumers (and hence could be evaluated multiple times), then we don't need to worry about work sharing. If instead we had expensive computations, which we might not want to evaluate more than once, then we would need a more complex system that attempted to share work between evaluations of lazy expressions. (It gets complex because when we evaluate a lazy subexpression we'd need to decide which of its intermediates to materialize, and then update all equivalent lazy subexpressions in the system with the materialized values we computed.) Unfortunately "cheap" is hard to define precisely without understanding XLA better (on all backends), but the kinds of operations outlined above are ones that we expect can be fused into consumers, and so are "cheap" because the values of these expressions may never be materialized at all.

Because we're attaching expressions to DeviceArrays, this design operates "underneath" all the tracing logic: it sits under the impl rules. In other words, as the toy model below makes clear, it's something one can do in any numerical library, and all of JAX's tracing and transformation machinery sits on top unmodified.

Lazy sublanguage

Other than being able to express the broadcasts, reshapes, and transposes we want, some design criteria for the language are:

  1. cheap expressions (e.g. expressions XLA fuses into consumers, avoiding direct evaluation)
  2. fixed-size canonical representation (i.e. it doesn't grow as we apply more broadcast/reshape/transpose operations), so that we get more cache hits and can manipulate expressions quickly

Here's the abstract syntax in terms of AST constructors:

data LazyExpr = LazyExpr Input Shape Dims
data Input = ArrayVar
           | Iota Dtype Size
           | Eye Dtype Shape Offset
           | Tri Dtype Shape Offset
           | Delta Dtype Shape Axes

type Shape = [Int]
type Dims = [Maybe Int]
type Size = Int
type Offset = Int
type Axes = [Int]

There are two components to a LazyExpr: an input and a reindexing specification. The input represents a base array to which the reindexing specification is applied.

An input can represent an array constructor (Iota, Eye, etc.) or it can be an ArrayVar, which encodes that the base array is some exogenous array value. (These LazyExprs are attached to DeviceArrays, so when the input part of the expression is an ArrayVar, the associated device buffer is the input; when the input is an array constructor, the associated device_buffer field of the DeviceArray is set to the sentinel value xla.device_constant.)

The reindexing specification encodes the shape of the final result and a list of dimensions, which are integers or Nones. The integer entries take on values 0, 1, ..., N-1 where N is the rank of the input array, and encode where the axes of the input array are to be mapped in the final output. When an entry is None that indicates that the corresponding axis of the result is a broadcasted one.
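
Equivalently (this just restates the evaluation rule that the toy model's materialize function below implements): if bcast_dims collects the output positions i with dims[i] not None, and perm collects the corresponding input axes dims[i] in that order, then the array denoted by LazyExpr(input, shape, dims) is

  broadcast_in_dim(transpose(eval(input), perm), shape, bcast_dims)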

The corresponding AST constructors in Python look like

LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
LazyArrayVar = namedtuple('ArrayVar', [])
LazyIota = namedtuple('Iota', ['dtype', 'size'])
LazyEye = namedtuple('Eye', ['dtype', 'shape', 'offset'])
LazyTri = namedtuple('Tri', ['dtype', 'shape', 'offset'])
LazyDelta = namedtuple('Delta', ['dtype', 'shape', 'axes'])

Here are some examples of lazy expressions and the arrays they represent:

LazyExpr(input=Iota(dtype=dtype('float32'), size=3), shape=(3, 4), dims=(0, None))
DeviceArray([[0., 0., 0., 0.],
             [1., 1., 1., 1.],
             [2., 2., 2., 2.]], dtype=float32)

LazyExpr(input=Iota(dtype=dtype('float32'), size=3), shape=(4, 3), dims=(None, 0))
DeviceArray([[0., 1., 2.],
             [0., 1., 2.],
             [0., 1., 2.],
             [0., 1., 2.]], dtype=float32)
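
For instance, here's an illustrative sketch (not taken from the PR) of how expressions like these could arise through the lax API once this change is in; the lazy metadata noted in the comments is what we'd expect rather than verified output:

import numpy as onp
from jax import lax

x = lax.iota(onp.float32, 3)               # lazy iota: shape (3,), dims (0,)
y = lax.broadcast_in_dim(x, (3, 4), (0,))  # still lazy: shape (3, 4), dims (0, None)
z = lax.transpose(y, (1, 0))               # still lazy: shape (4, 3), dims (None, 0)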

See xla.py for a numpy-based interpreter of this language.

Toy model

Here's a model showing the main idea (except for the changes to op-by-op and jit/pmap logic in xla.py).

from collections import namedtuple
from functools import partial
import itertools as it

import numpy as onp
from jax.util import unzip2

LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
ArrayVar = namedtuple('ArrayVar', [])
Iota = namedtuple('Iota', ['dtype', 'size'])


class Array(object):
  __slots__ = ['buf', 'lexpr']
  def __init__(self, buf, lexpr):
    self.buf = buf
    self.lexpr = lexpr

def materialize(arr):
  input_, shape, dims = arr.lexpr

  t = type(input_)
  if t is ArrayVar:
    x = arr.buf
  elif t is Iota:
    assert arr.buf is None
    x = onp.arange(input_.size, dtype=input_.dtype)
  else:
    assert False

  bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None)
  transposed = onp.transpose(x, perm)
  out = _broadcast_in_dim(transposed, shape, bcast_dims)

  return out

# NumPy implementation of lax.broadcast_in_dim (from lax_reference.py)
def _broadcast_in_dim(operand, shape, bcast_dims):
  inshape = tuple(1 if i not in bcast_dims else d for i, d in enumerate(shape))
  return onp.broadcast_to(onp.reshape(operand, inshape), shape)


# making an array (with a trivial identity reindex expression)
def make_array(buf):
  lazy_expr = LazyExpr(ArrayVar(), buf.shape, tuple(range(buf.ndim)))
  return Array(buf, lazy_expr)

# making an iota
def make_iota(dtype, size):
  lazy_expr = LazyExpr(Iota(dtype, size), (size,), (0,))
  return Array(None, lazy_expr)

# now for operations! first, a lazy broadcasting operation, modeling XLA's.
def lazy_broadcast(arr, shape, bcast_dims):
  lazy_expr = arr.lexpr
  new_dims = [None] * len(shape)
  for i, d in enumerate(bcast_dims):
    new_dims[d] = lazy_expr.dims[i]
  new_lazy_expr = LazyExpr(lazy_expr.input, shape, new_dims)
  return Array(arr.buf, new_lazy_expr)

# second, a transpose (HLO transpose uses "where it comes from" encoding)
def lazy_transpose(arr, perm):
  lazy_expr = arr.lexpr
  new_shape = [lazy_expr.shape[i] for i in perm]
  new_dims = [lazy_expr.dims[i] for i in perm]
  new_lazy_expr = LazyExpr(lazy_expr.input, new_shape, new_dims)
  return Array(arr.buf, new_lazy_expr)
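
To see the toy model in action, here's a small usage sketch (not part of the PR) that reproduces the two example expressions from the "Lazy sublanguage" section using only the functions defined above:

x = make_iota(onp.float32, 3)        # LazyExpr with Iota(float32, 3) input, shape (3,), dims (0,)
y = lazy_broadcast(x, (3, 4), (0,))  # shape (3, 4), dims [0, None]; still no buffer materialized
z = lazy_transpose(y, (1, 0))        # shape [4, 3], dims [None, 0]

print(materialize(y))  # rows [0 0 0 0], [1 1 1 1], [2 2 2 2]
print(materialize(z))  # four copies of the row [0 1 2]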

Micro-benchmarks

This PR adds some work to the op-by-op dispatch path, so we want to check that there isn't a significant performance regression. In the absence of proper performance regression tests, I just did some quick checks by hand.

On master (CPU):

In [1]: from jax import lax
In [2]: timeit -r 100 -n 1000 lax.add(1, 1).block_until_ready()
92.9 µs ± 8.3 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)
In [3]: x = lax.add(1, 1)
In [4]: timeit -r 100 -n 1000 lax.add(x, x).block_until_ready()
46.4 µs ± 5.11 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

On branch (CPU):

In [1]: from jax import lax
In [2]: timeit -r 100 -n 1000 lax.add(1, 1).block_until_ready()
94.2 µs ± 9.03 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)
In [3]: x = lax.add(1, 1)
In [4]: timeit -r 100 -n 1000 lax.add(x, x).block_until_ready()
47.1 µs ± 5.36 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

I did something similar with jit(lax.add) and there wasn't any movement there either.

Implementation notes

  • The lazy language itself (namely the AST constructor namedtuples, a NumPy-based evaluator eval_lexpr, and an XLA-based interpreter stage_lexpr) is added in lazy.py.
  • The impl rules of relevant primitives are updated in lax.py, so that e.g. a broadcast on a DeviceArray only updates metadata (as in the toy model above).
  • The DeviceArray class in xla.py is updated to keep a _lazy_expr attribute.
  • I made a couple of explicit sentinels, like device_constant, rather than using None in some places.
  • DeviceConstant, FilledConstant, IotaConstant, EyeConstant are all removed.

Fixes #1909, and incidentally fixes #1431 because I happened to be updating that code.

mattjj (Collaborator, Author) commented Dec 4, 2019

Running the mnist_vae.py example, I noticed something unfortunate: we were compiling the update loop (run_epoch) twice!

First call's signature:

WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(None,)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), 
shape=(512, 784), dims=(None, None)), xla_shape=f32[]), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(None,)), xla_shape=f32[])).

Second and subsequent calls' signatures:

WARNING:absl:Compiling run_epoch for args (ArgSpec(aval=ShapedArray(uint32[2]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(2,), dims=(0,)), xla_shape=u32[2]{0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[512,10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 10), dims=(0, 1)), xla_shape=f32[512,10]{1,0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10,), dims=(0,)), xla_shape=f32[10]{0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[10,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(10, 512), dims=(0, 1)), xla_shape=f32[10,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 512), dims=(0, 1)), xla_shape=f32[512,512]{1,0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512,), dims=(0,)), xla_shape=f32[512]{0}), ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), 
ArgSpec(aval=ShapedArray(float32[512,784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(512, 784), dims=(0, 1)), xla_shape=f32[512,784]{1,0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0}), ArgSpec(aval=ShapedArray(float32[784]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784,), dims=(0,)), xla_shape=f32[784]{0})).

The difference is in the broadcasts, like this pair of signature entries:

ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(None, None)), xla_shape=f32[]),
ArgSpec(aval=ShapedArray(float32[784,512]), lazy_expr=LazyExpr(input=ArrayVar(), shape=(784, 512), dims=(0, 1)), xla_shape=f32[784,512]{1,0}),

Because of the np.zeros used to initialize optimizers in optimizers.py, on the first call to run_epoch we're fusing the formation of those lazy zeros into the run_epoch computation, meaning we compile for a signature with scalar arguments and broadcasts on the inputs. Then on the second call to run_epoch we need to recompile for dense array inputs. (If we add an xla.force call to the zeros creation, like xla.force(np.zeros_like(x0)) to the init function of momentum in optimizers.py, we don't have to recompile, as on master.)
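
For concreteness, here's a sketch of what that workaround would look like (the init function shown is illustrative, modeled on optimizers.momentum, not the exact optimizers.py code):

def init(x0):
  v0 = xla.force(np.zeros_like(x0))  # materialize the lazy zeros eagerly so jit sees a dense array argument
  return x0, v0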

Two possible solutions:

  1. make jit force its arguments, as on master (while still keeping op-by-op application and jit closure nonstrict)
  2. change optimizers.py, and similar user code, to call xla.force in cases like this

Discussing with @hawkinsp, we think Option 1 sounds like the better heuristic. (As an extension, if we want it, we could add a nonstrict_argnums to jit to control this behavior, maybe useful for library writers.) Forcing users to think about laziness and change their code, as we'd need to change optimizers.py, is costly, and we don't yet have real use cases where this strict-jit policy would be problematic.

@hawkinsp articulated these principles:

  • if the computation is expensive to recompile (as is plausible with jit) then we don't want recompiles (because of different lazy expressions on the inputs), whereas for op-by-op broadcasts, singleton-adding reshapes, and transposes we'd like to avoid those compiles (and dispatches and materialization) entirely
  • it's less surprising if we don't recompile by default for jit functions like this one

If we want to be fancy, we could have a heuristic like: force arguments for "big" or slow-to-compile computations, and be lazy otherwise. We have all that information in _xla_callable in xla.py. But we'll start simple!

mattjj added a commit that referenced this pull request Dec 4, 2019
This change is to avoid recompiles. See comment:
#1668 (comment)
Thanks @hawkinsp for help with this.

Also, make force(x) update x's device_buffer reference.
mattjj added a commit that referenced this pull request Dec 31, 2019
Before this commit, evaluating x[:, None] * x[None, :] for a vector x in
op-by-op (eager) mode would compile and execute 3 XLA computations and
materialize a total of 3 result buffers. After this commit, it compiles
and executes 1 XLA computation and materializes only one result buffer.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

Finally, this commit replaces the ad-hoc "lazy device constant" system.

See #1668 for more.
mattjj mentioned this pull request Jan 7, 2020
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See #1668 for more.
mattjj merged commit ad9b6d4 into master on Jan 8, 2020
mattjj deleted the lazy-sublanguage branch on January 8, 2020 04:49
gnecula (Collaborator) commented Jan 8, 2020

It would be really nice to pull much of this PR's description into the documentation; otherwise that info may be hard to find.

mattjj (Collaborator, Author) commented Jan 8, 2020

You're right. The "Lazy Sublanguage" section appears as a comment in the source code but I think you probably meant the jax.readthedocs.io documentation. I'll add it to my todos (documentation on this as well as on the device placement policy stuff we discussed yesterday).

gnecula added the "useful read" label on Feb 27, 2020
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.