Skip to content

Commit

Permalink
simplify remat partial eval parameterization
Browse files Browse the repository at this point in the history
The main win here is reducing the number of arguments for the function
that parameterizes _remat_partial_eval (so it can be used both with
remat and invertible ad features).

I also included a fix to _remat_partial_eval that is needed in #3370,
though I don't think it's needed on master. It was easier to include the
fix now.

Both these changes made rebasing #3370 easier!
  • Loading branch information
mattjj committed Jun 16, 2020
1 parent 6bcf056 commit 05e0716
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 221 deletions.
43 changes: 23 additions & 20 deletions jax/interpreters/invertible_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from . import ad
from . import partial_eval as pe
from .partial_eval import (PartialVal, partial_eval_jaxpr,
JaxprTracer, ConstVar, convert_constvars_jaxpr, new_eqn_recipe)
JaxprTracer, ConstVar, convert_constvars_jaxpr,
new_eqn_recipe, _partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
Expand All @@ -43,24 +44,25 @@
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True

def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_known_pvals, out_unknown_pvals, _, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
# Add dummy arguments representing the outputs to the jaxpr. Those should remain unused in case
# the expression actually ends up being evaluated, but they make it well-formed.
out_known_avals = tuple(raise_to_shaped(get_aval(pval.get_known())) for pval in out_known_pvals)
lifted_jaxpr = _append_invars(lifted_jaxpr, out_known_avals)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# We also append some dummy outputs that correspond to the known outputs we left in the call_jaxpr
dummy_outputs = [JaxprTracer(trace, pval, core.unit) for pval in out_known_pvals]

output_constants = [JaxprTracer(trace, pval, ConstVar(pval.get_known())) for pval in out_known_pvals]
eqn = new_eqn_recipe(tuple(it.chain(in_tracers, output_constants)),
dummy_outputs + unknown_output_tracers,
invertible_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
uks = [not t.pval.is_known() for t in out_tracers]
out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)

# Add dummy arguments representing the outputs to the jaxpr. Those should
# remain unused if the expression is evaluated, but they make it well-formed.
out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
new_in_tracers = (*in_tracers, *out_consts)

# Append dummy outputs that correspond to known outputs left in the call_jaxpr
dummy_outputs = [JaxprTracer(trace, t.pval, core.unit) for t in out_tracers_known]
new_out_tracers = (*dummy_outputs, *out_tracers_unknown)

eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p,
dict(params, call_jaxpr=new_jaxpr))
for t in out_tracers_unknown: t.recipe = eqn
return new_out_tracers

pe.call_partial_eval_rules[invertible_call_p] = partial(
pe._remat_partial_eval, _invertible_call_make_output_tracers)
Expand All @@ -69,7 +71,8 @@ def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_kno
@cache()
def _append_invars(jaxpr, avals):
newvar = core.gensym([jaxpr])
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), jaxpr.outvars, jaxpr.eqns)
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals),
jaxpr.outvars, jaxpr.eqns)


def _invertible_call_transpose(params, call_jaxpr, args, ct, _):
Expand Down
110 changes: 57 additions & 53 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,12 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True

def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# We reuse the _remat_partial_eval function both for remat_call and for
# invertible_call, both of which in a sense stage out operations to
# rematerialize values. The two usages differ only in details of what jaxpr eqn
# and output tracers are formed. As a result we parameterize _remat_partial_eval
# by a `process_out` function.
def _remat_partial_eval(process_out, trace, _, f, tracers, params):
concrete = params['concrete']

# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
Expand All @@ -699,24 +704,27 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):

# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvals = [t.pval for t in instantiated_tracers]
with core.initial_style_staging():
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))

# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
const_tracers = map(trace.new_instantiated_const, consts)

# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.

in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
+ [raise_to_shaped(pval.get_aval()) for pval in in_pvals])
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] +
[raise_to_shaped(t.pval.get_aval()) for t in env_tracers] +
[raise_to_shaped(pval.get_aval()) for pval in in_pvals])
out_avals = [raise_to_shaped(abstract_unit if var is unitvar
else get_aval(var.val) if type(var) is Literal
else pval.get_aval())
for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)]
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
instantiate=False,
trace_type=trace.master.trace_type)
typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

Expand All @@ -728,56 +736,52 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# values. For the use case of inverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [type(pval[0]) is not ConcreteArray
for uk, pval in zip(out_unknowns, eval_out_pvals)
if not uk]
for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
num_outputs = len(jaxpr_unknown.out_avals)
num_res = len(jaxpr_known.out_avals) - num_outputs
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts)
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

# Now that we have out_pvals, the rest is similar to JaxprTrace.process_call.
# Known outputs should keep propagating as constants
assert all(pv.is_known() for pv in out_known_pvals)
known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals]

# Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call
const_tracers = map(trace.new_instantiated_const, consts)
unknown_output_tracers = wrap_unknown_pvals(
trace,
typed_jaxpr,
tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)),
out_known_pvals,
out_unknown_pvals,
out_unknowns,
params)

return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)

def _remat_make_output_tracers(trace, typed_jaxpr, input_tracers, _, out_unknown_pvals, out_unknowns, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True)
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(input_tracers,
unknown_output_tracers,
remat_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers

call_partial_eval_rules[remat_call_p] = partial(_remat_partial_eval, _remat_make_output_tracers)

def _partition_knowns(l, unknowns):
return ([e for e, unknown in zip(l, unknowns) if not unknown],
[e for e, unknown in zip(l, unknowns) if unknown])

def _zip_knowns(kl, ul, unknowns):
ul_it = iter(ul)
kl_it = iter(kl)
return [next(ul_it) if unknown else next(kl_it) for unknown in unknowns]
known_output_tracers = [trace.new_const(pval.get_known())
for pval in out_known_pvals]
# Unknown outputs get wrapped in tracers with the appropriate recipe
unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
for out_pval in out_unknown_pvals]
out_tracers = _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)

in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(typed_jaxpr.jaxpr))
return process_out(trace, in_tracers, out_tracers, new_params)

def _remat_make_output_tracers(_, in_tracers, out_tracers, params):
# dce jaxpr outputs
jaxpr = params['call_jaxpr']
out_unknowns = [not t.pval.is_known() for t in out_tracers]
typed_jaxpr = core.TypedJaxpr(jaxpr, (), [v.aval for v in jaxpr.invars],
[v.aval for v in jaxpr.outvars])
new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
new_params = dict(params, call_jaxpr=new_jaxpr)

# set up eqn for unknown outputs
unknown_out_tracers = [t for t in out_tracers if not t.pval.is_known()]
eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params)
for t in unknown_out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = partial(
_remat_partial_eval, _remat_make_output_tracers)

def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])

def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]


def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr:
Expand Down
Loading

0 comments on commit 05e0716

Please sign in to comment.