Skip to content

Commit

Permalink
Merge pull request #3459 from google/simplify-remat-partial-eval
Browse files Browse the repository at this point in the history
simplify remat partial eval parameterization
  • Loading branch information
mattjj authored Jun 16, 2020
2 parents 140c9ea + dfdf05f commit 711c93d
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 223 deletions.
45 changes: 23 additions & 22 deletions jax/interpreters/invertible_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# limitations under the License.

from functools import partial
import itertools as it
from typing import Dict, Any, Callable

import jax
from jax import core
from jax import linear_util as lu
from . import ad
from . import partial_eval as pe
from .partial_eval import (PartialVal, partial_eval_jaxpr,
JaxprTracer, ConstVar, convert_constvars_jaxpr, new_eqn_recipe)
from .partial_eval import (PartialVal, partial_eval_jaxpr, new_eqn_recipe,
_partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
Expand All @@ -43,24 +42,25 @@
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True

def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_known_pvals, out_unknown_pvals, _, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
# Add dummy arguments representing the outputs to the jaxpr. Those should remain unused in case
# the expression actually ends up being evaluated, but they make it well-formed.
out_known_avals = tuple(raise_to_shaped(get_aval(pval.get_known())) for pval in out_known_pvals)
lifted_jaxpr = _append_invars(lifted_jaxpr, out_known_avals)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# We also append some dummy outputs that correspond to the known outputs we left in the call_jaxpr
dummy_outputs = [JaxprTracer(trace, pval, core.unit) for pval in out_known_pvals]

output_constants = [JaxprTracer(trace, pval, ConstVar(pval.get_known())) for pval in out_known_pvals]
eqn = new_eqn_recipe(tuple(it.chain(in_tracers, output_constants)),
dummy_outputs + unknown_output_tracers,
invertible_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
uks = [not t.pval.is_known() for t in out_tracers]
out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)

# Add dummy arguments representing the outputs to the jaxpr. Those should
# remain unused if the expression is evaluated, but they make it well-formed.
out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
new_in_tracers = (*in_tracers, *out_consts)

# Append dummy outputs that correspond to known outputs left in the call_jaxpr
dummy_outputs = [trace.new_const(t.pval.get_known()) for t in out_tracers_known]
new_out_tracers = (*dummy_outputs, *out_tracers_unknown)

eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p,
dict(params, call_jaxpr=new_jaxpr))
for t in out_tracers_unknown: t.recipe = eqn
return new_out_tracers

pe.call_partial_eval_rules[invertible_call_p] = partial(
pe._remat_partial_eval, _invertible_call_make_output_tracers)
Expand All @@ -69,7 +69,8 @@ def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_kno
@cache()
def _append_invars(jaxpr, avals):
newvar = core.gensym([jaxpr])
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), jaxpr.outvars, jaxpr.eqns)
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals),
jaxpr.outvars, jaxpr.eqns)


def _invertible_call_transpose(params, call_jaxpr, args, ct, _):
Expand Down
110 changes: 57 additions & 53 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,12 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True

def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# We reuse the _remat_partial_eval function both for remat_call and for
# invertible_call, both of which in a sense stage out operations to
# rematerialize values. The two usages differ only in details of what jaxpr eqn
# and output tracers are formed. As a result we parameterize _remat_partial_eval
# by a `process_out` function.
def _remat_partial_eval(process_out, trace, _, f, tracers, params):
concrete = params['concrete']

# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
Expand All @@ -699,24 +704,27 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):

# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvals = [t.pval for t in instantiated_tracers]
with core.initial_style_staging():
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))

# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
const_tracers = map(trace.new_instantiated_const, consts)

# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.

in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
+ [raise_to_shaped(pval.get_aval()) for pval in in_pvals])
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] +
[raise_to_shaped(t.pval.get_aval()) for t in env_tracers] +
[raise_to_shaped(pval.get_aval()) for pval in in_pvals])
out_avals = [raise_to_shaped(abstract_unit if var is unitvar
else get_aval(var.val) if type(var) is Literal
else pval.get_aval())
for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)]
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
instantiate=False,
trace_type=trace.master.trace_type)
typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

Expand All @@ -728,56 +736,52 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# values. For the use case of inverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [type(pval[0]) is not ConcreteArray
for uk, pval in zip(out_unknowns, eval_out_pvals)
if not uk]
for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
num_outputs = len(jaxpr_unknown.out_avals)
num_res = len(jaxpr_known.out_avals) - num_outputs
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts)
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

# Now that we have out_pvals, the rest is similar to JaxprTrace.process_call.
# Known outputs should keep propagating as constants
assert all(pv.is_known() for pv in out_known_pvals)
known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals]

# Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call
const_tracers = map(trace.new_instantiated_const, consts)
unknown_output_tracers = wrap_unknown_pvals(
trace,
typed_jaxpr,
tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)),
out_known_pvals,
out_unknown_pvals,
out_unknowns,
params)

return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)

def _remat_make_output_tracers(trace, typed_jaxpr, input_tracers, _, out_unknown_pvals, out_unknowns, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True)
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(input_tracers,
unknown_output_tracers,
remat_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers

call_partial_eval_rules[remat_call_p] = partial(_remat_partial_eval, _remat_make_output_tracers)

def _partition_knowns(l, unknowns):
return ([e for e, unknown in zip(l, unknowns) if not unknown],
[e for e, unknown in zip(l, unknowns) if unknown])

def _zip_knowns(kl, ul, unknowns):
ul_it = iter(ul)
kl_it = iter(kl)
return [next(ul_it) if unknown else next(kl_it) for unknown in unknowns]
known_output_tracers = [trace.new_const(pval.get_known())
for pval in out_known_pvals]
# Unknown outputs get wrapped in tracers with the appropriate recipe
unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
for out_pval in out_unknown_pvals]
out_tracers = _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)

in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(typed_jaxpr.jaxpr))
return process_out(trace, in_tracers, out_tracers, new_params)

def _remat_make_output_tracers(_, in_tracers, out_tracers, params):
# dce jaxpr outputs
jaxpr = params['call_jaxpr']
out_unknowns = [not t.pval.is_known() for t in out_tracers]
typed_jaxpr = core.TypedJaxpr(jaxpr, (), [v.aval for v in jaxpr.invars],
[v.aval for v in jaxpr.outvars])
new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
new_params = dict(params, call_jaxpr=new_jaxpr)

# set up eqn for unknown outputs
unknown_out_tracers = [t for t in out_tracers if not t.pval.is_known()]
eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params)
for t in unknown_out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = partial(
_remat_partial_eval, _remat_make_output_tracers)

def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])

def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]


def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr:
Expand Down
Loading

0 comments on commit 711c93d

Please sign in to comment.