Skip to content

Commit

Permalink
fix remat bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jun 13, 2020
1 parent 005763e commit a918e9b
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .. import core
from .. import dtypes
from .. import linear_util as lu
from ..lib import xla_client # flake8: noqa
from ..abstract_arrays import ConcreteArray, raise_to_shaped
from ..ad_util import Zero
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
Expand Down Expand Up @@ -646,19 +647,24 @@ def _remat_partial_eval(trace, _, f, tracers, params):
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))

# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
const_tracers = map(trace.new_instantiated_const, consts)

# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.

in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
+ [raise_to_shaped(pval.get_aval()) for pval in in_pvals])
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] +
[raise_to_shaped(t.pval.get_aval()) for t in env_tracers] +
[raise_to_shaped(pval.get_aval()) for pval in in_pvals])
out_avals = [raise_to_shaped(abstract_unit if var is unitvar
else get_aval(var.val) if type(var) is Literal
else pval.get_aval())
for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)]
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
instantiate=False)
typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False)
out_knowns = [not b for b in out_unknowns]

# First, we prune the jaxpr to be staged out not to have too many outputs.
Expand All @@ -681,23 +687,21 @@ def _remat_partial_eval(trace, _, f, tracers, params):
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts)
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

# Now that we have out_pvals, the rest is similar to JaxprTrace.process_call.
# Known outputs should keep propagating as constants
assert all(pv.is_known() for pv in out_known_pvals)
known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals]
known_output_tracers = [trace.new_const(pval.get_known())
for pval in out_known_pvals]

# Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call
const_tracers = map(trace.new_instantiated_const, consts)
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)),
unknown_output_tracers,
remat_call_p,
new_params)
unknown_output_tracers, remat_call_p, new_params)
for t in unknown_output_tracers: t.recipe = eqn

return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
Expand Down Expand Up @@ -792,14 +796,15 @@ def full_lower(self):
def _contents(self):
return ()

# TODO(mattjj); re-enable after #3421 and jaxlib
# def __bool__(self):
# self._concretization_error('__bool__')

# def __int__(self):
# self._concretization_error('__int__')

# def _concretization_error(self, name):
# msg = self._progenitor_messages()
# msgs = self._progenitor_messages()
# msg = (f"Abstract tracer value passed to {name} for which a concrete value "
# "is required.\n"
# "This tracer originated from using JAX operations on these lines:\n"
Expand Down Expand Up @@ -925,7 +930,7 @@ def process_primitive(self, primitive, tracers, params):
out_tracers = [JaxprTracer2(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info())
eqn = new_jaxpr_eqn(invars, outvars, primitive, params)
self.frame.eqns.append(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()

Expand All @@ -940,8 +945,7 @@ def process_call(self, call_primitive, f, tracers, params):
update_params = call_param_updaters.get(call_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params,
source_info())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params)
self.frame.eqns.append(eqn)
return out_tracers

Expand All @@ -967,8 +971,7 @@ def process_map(self, map_primitive, f, tracers, params):
update_params = call_param_updaters.get(map_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params,
source_info())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params)
self.frame.eqns.append(eqn)
return out_tracers

Expand Down Expand Up @@ -1011,13 +1014,19 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
return jaxpr, out_avals, consts


class FrameInfo(NamedTuple):
filename: str
lineno: int
# TODO(mattjj): re-enable after #3421 and jaxlib
# class FrameInfo(NamedTuple):
# filename: str
# lineno: int

def source_info():
return [FrameInfo(f.filename, f.lineno) for f in inspect.stack()]
# def source_info():
# try:
# t = xla_client.Traceback.get_traceback()
# except AttributeError:
# return None
# else:
# return [FrameInfo(f.filename, f.lineno) for f in t.frames]

def user_source_info(frame_infos):
base = os.sep.join(__file__.split(os.sep)[:-2])
return next((f for f in frame_infos if base not in f.filename), None)
# def user_source_info(frame_infos):
# base = os.sep.join(__file__.split(os.sep)[:-2])
# return next((f for f in frame_infos if base not in f.filename), None)

0 comments on commit a918e9b

Please sign in to comment.