when transposing a jaxpr, do a constant eval pass
Before this change, the jaxpr

```
{ lambda  ;  ; a.
  let b = add 2. 3.
      c = mul b a
  in [c] }
```

was not transposable, even though it is linear. The issue was that our
transposition machinery assumed, in effect, that there were no constant
subexpressions: constants could appear only as literals or as constvars.
So this jaxpr would work:

```
{ lambda  ;  ; a.
  let b = mul 5. a
  in [b] }
```

However, as in #1715, using `tie_in` to control which parts of the
computation get staged out to jaxprs, and hence to XLA (e.g. to avoid
materializing large constants), could create jaxprs with constant
subexpressions (indeed, that's the point!).
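
For concreteness, here is a minimal sketch of the kind of function involved
(the function and constants are illustrative, not taken from #1715):

```
from jax import grad, jit, lax

def f(x):
  two = lax.tie_in(x, 2.)  # keep the constant in the jaxpr instead of folding it
  return (two + 3.) * x    # `add 2. 3.` is a constant subexpression; f is linear in x
```

Under `jit`, tracing `f` produces a jaxpr much like the first one above, which
the backward pass must then transpose.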

This commit resolves the issue by adding a constant evaluation pass to
`backward_pass` in ad.py. In the grad-of-jit case, this evaluation is
staged out to XLA.
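
Continuing the sketch above:

```
grad(jit(f))(1.0)  # transposes the jaxpr of f; the constant eval of `add 2. 3.`
                   # happens while the backward pass is itself being traced, so
                   # it is staged out to XLA rather than run eagerly. Result: 5.0
```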

We thought about having `tie_in` not appear in jaxprs (as in #1647), but
that would remove our ability to stage computations out to the backward
pass (e.g. to avoid materializing a large constant in the backward
pass).

The solution here is a stop-gap that we hope will be cleaned up by #1677,
which aims to provide an explicit embedding API to replace the use of `tie_in`.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
mattjj and dougalm committed Nov 19, 2019
1 parent 7f7c21f commit 93132dc
Showing 5 changed files with 56 additions and 18 deletions.
jax/core.py (8 additions, 0 deletions)
```
@@ -663,3 +663,11 @@ def pp_jaxpr(jaxpr):
     ((pp('let ') >>
       vcat(map(pp_eqn, jaxpr.eqns))) +
      pp('in {} }}'.format(jaxpr.outvars))).indent(2))
+
+
+def tie_in(x, y):
+  return tie_in_p.bind(x, y)
+
+tie_in_p = Primitive('tie_in')
+tie_in_p.def_impl(lambda x, y: y)
+tie_in_p.def_abstract_eval(lambda x, y: y)
```
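
Moving the primitive into `core` lets `ad.py` refer to `core.tie_in_p`
directly, plausibly sidestepping a circular import (`ad.py` cannot import
`lax`, which itself imports the interpreters). Outside of any trace, the impl
rule makes `tie_in` a plain identity on its second argument; a quick
illustrative check:

```
from jax import core

assert core.tie_in(3.0, 4.0) == 4.0  # def_impl: lambda x, y: y
```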
jax/interpreters/ad.py (39 additions, 1 deletion)
```
@@ -139,6 +139,9 @@ def unpair_pval(pval):
   return (aval_1, const_1), (aval_2, const_2)
 
 def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
+  if all(ct is zero for ct in cotangents_in):
+    return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
+
   def write_cotangent(v, ct):
     # assert v not in primal_env
     if ct is not None:
@@ -162,9 +165,42 @@ def write_primal(v, val):
   map(write_primal, jaxpr.freevars, freevar_vals)
   map(write_primal, jaxpr.invars, args)
 
+  def is_linear(var):
+    if type(var) is Literal:
+      return False
+    else:
+      return primal_env.get(var, undefined_primal) is undefined_primal
+
+  # TODO(mattjj,dougalm): revise with explicit embedding
+  linear_eqns = []
+  meat_hook = next(ct for ct in cotangents_in if ct is not zero)
+  for eqn in jaxpr.eqns:
+    if eqn.primitive is core.tie_in_p:
+      if is_linear(eqn.invars[1]):
+        linear_eqns.append(eqn)
+      else:
+        ans = core.tie_in(meat_hook, read_primal(eqn.invars[1]))
+        write_primal(eqn.outvars[0], ans)
+    elif any(is_linear(v) for v in eqn.invars):
+      linear_eqns.append(eqn)
+    else:
+      if eqn.bound_subjaxprs:
+        (_, const_bindings, freevar_bindings), = eqn.bound_subjaxprs
+        if any(is_linear(v) for v in it.chain(const_bindings, freevar_bindings)):
+          linear_eqns.append(eqn)
+        else:
+          raise NotImplementedError
+      else:
+        in_vals = map(read_primal, eqn.invars)
+        ans = eqn.primitive.bind(*in_vals, **eqn.params)
+        if eqn.primitive.multiple_results:
+          map(write_primal, eqn.outvars, ans)
+        else:
+          write_primal(eqn.outvars[0], ans)
+
   ct_env = {}
   map(write_cotangent, jaxpr.outvars, cotangents_in)
-  for eqn in jaxpr.eqns[::-1]:
+  for eqn in linear_eqns[::-1]:
     invals = map(read_primal, eqn.invars)
     if eqn.primitive.multiple_results:
       cts_in = map(read_cotangent, eqn.outvars)
@@ -514,3 +550,5 @@ def _perm(primal_counts, tangent_counts, lst):
 def _interleave(xs, ys):
   assert len(xs) == len(ys)
   return [e for pair in zip(xs, ys) for l in pair for e in l]
+
+deflinear(core.tie_in_p, lambda t: [zero, t])
```
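
Two details above are worth calling out. The `deflinear` registration at the
bottom says `tie_in` is linear in its second argument, so transposition routes
the incoming cotangent to `y` and a symbolic zero to `x`. And `meat_hook` (the
first nonzero incoming cotangent) re-ties each evaluated constant
subexpression, so that in the grad-of-jit case the constant evaluation stays
tied into the staged-out backward computation rather than escaping it. A small
illustrative check of the transpose rule:

```
from jax import grad, lax

g = grad(lambda x, y: lax.tie_in(x, y) * 3., argnums=(0, 1))
print(g(1., 2.))  # (0.0, 3.0): zero flows through the tie, 3.0 flows to y
```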
jax/interpreters/xla.py (2 additions, 0 deletions)
```
@@ -763,3 +763,5 @@ def _instantiate_device_constant(const, device=None, backend=None, cutoff=1e6):
   else:
     return xc.Buffer.from_pyval(onp.asarray(const), device,
                                 backend=xb.get_backend(backend))
+
+translations[core.tie_in_p] = lambda c, x, y: y
```
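
That is, compiled code never sees the tie: `tie_in(x, y)` lowers to just `y`,
and the dependence on `x` matters only at trace time. One illustrative way to
observe the effect:

```
import jax.numpy as np
from jax import make_jaxpr, lax

f2 = lambda x: np.sum(lax.tie_in(x, np.ones(3))) * x
print(make_jaxpr(f2)(1.0))  # the tied constant and the sum stay in the jaxpr
```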
jax/lax/lax.py (6 additions, 16 deletions)
```
@@ -37,7 +37,7 @@
 from .. import linear_util as lu
 from .. import dtypes
 from ..config import flags
-from ..core import Primitive
+from ..core import Primitive, tie_in
 from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
                                AbstractToken, array_types, make_shaped_array,
                                raise_to_shaped, abstract_token)
@@ -1026,9 +1026,6 @@ def sort_key_val(keys, values, dimension=-1):
   return sorted_keys, sorted_values
 
 
-def tie_in(x, y):
-  return tie_in_p.bind(x, y)
-
 def shaped_identity(x):
   return shaped_identity_p.bind(x, shape=x.shape)
 
@@ -2583,7 +2580,7 @@ def _select_dtype_rule(pred, on_true, on_false):
   return on_true.dtype
 
 def _select_transpose_rule(t, pred, on_true, on_false):
-  assert pred is not ad.undefined_primal
+  assert pred is not ad.undefined_primal  # TODO
   if t is ad_util.zero:
     return [None,
             ad_util.zero if on_true is ad.undefined_primal else None,
@@ -3978,22 +3975,15 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
 batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
 
 
-def _tie_in_transpose_rule(t):
-  return [ad_util.zero, t]
-
+# TODO move these
 def _tie_in_batch_rule(batched_args, batch_dims):
   y = tie_in(*batched_args)
   _, bdim_y = batch_dims
   return y, bdim_y
+batching.primitive_batchers[core.tie_in_p] = _tie_in_batch_rule
+masking.shape_rules[core.tie_in_p] = lambda shape_exprs: shape_exprs[1]
+masking.masking_rules[core.tie_in_p] = lambda vals, logical_shapes: vals[1]
 
-tie_in_p = Primitive('tie_in')
-tie_in_p.def_impl(lambda x, y: y)
-tie_in_p.def_abstract_eval(lambda x, y: y)
-xla.translations[tie_in_p] = lambda c, x, y: y
-ad.deflinear(tie_in_p, _tie_in_transpose_rule)
-batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
-masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1]
-masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
 
 shaped_identity_p = Primitive('shape_id')
 shaped_identity_p.def_impl(lambda x, shape: x)
```
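
The batching rule retained here simply follows the second operand: under
`vmap`, the output of `tie_in(x, y)` is batched exactly as `y` is. An
illustrative example:

```
import jax.numpy as np
from jax import vmap, lax

xs = np.arange(3.)
print(vmap(lambda x: lax.tie_in(x, 1.) * x)(xs))  # [0. 1. 2.]
```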
jax/lax/lax_parallel.py (1 addition, 1 deletion)
```
@@ -414,7 +414,7 @@ def _defidentity(prim, argnum=0):
 _defbroadcasting(lax.shift_right_arithmetic_p)
 _defbroadcasting(lax.shift_right_logical_p)
 
-_defidentity(lax.tie_in_p)
+_defidentity(core.tie_in_p)
 
 _defreducer(lax.reduce_sum_p, psum_p)
 _defreducer(lax.reduce_max_p, pmax_p)
```
