diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index f48985815397..5fe3a6231763 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -26,6 +26,7 @@ from .. import core from .. import dtypes from .. import linear_util as lu +from ..lib import xla_client # flake8: noqa from ..abstract_arrays import ConcreteArray, raise_to_shaped from ..ad_util import Zero from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list, @@ -646,19 +647,24 @@ def _remat_partial_eval(trace, _, f, tracers, params): jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval( f, in_pvals, partial(remat_call_p.bind, **params)) + # Convert consts to inputs, since they may contain Tracer instances. + jaxpr = convert_constvars_jaxpr(jaxpr) + const_tracers = map(trace.new_instantiated_const, consts) + # Since we traced with everything marked as unknown, but we need to know which # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns. 
- - in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers] - + [raise_to_shaped(pval.get_aval()) for pval in in_pvals]) + in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] + + [raise_to_shaped(t.pval.get_aval()) for t in env_tracers] + + [raise_to_shaped(pval.get_aval()) for pval in in_pvals]) out_avals = [raise_to_shaped(abstract_unit if var is unitvar else get_aval(var.val) if type(var) is Literal else pval.get_aval()) for var, pval in zip(jaxpr.outvars, eval_out_pvals)] - typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) - in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)] - jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, - instantiate=False) + typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals) + in_unknowns = ([False] * len(consts) + + [not t.is_known() for t in it.chain(env_tracers, tracers)]) + jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr( + typed_jaxpr, in_unknowns, instantiate=False) out_knowns = [not b for b in out_unknowns] # First, we prune the jaxpr to be staged out not to have too many outputs. @@ -681,23 +687,21 @@ def _remat_partial_eval(trace, _, f, tracers, params): jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True) jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute) _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers)) - reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts) + reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts) out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts) # Now that we have out_pvals, the rest is similar to JaxprTrace.process_call. 
# Known outputs should keep propagating as constants assert all(pv.is_known() for pv in out_known_pvals) - known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals] + known_output_tracers = [trace.new_const(pval.get_known()) + for pval in out_known_pvals] # Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call - const_tracers = map(trace.new_instantiated_const, consts) unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals] lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr) new_params = dict(params, call_jaxpr=lifted_jaxpr) eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)), - unknown_output_tracers, - remat_call_p, - new_params) + unknown_output_tracers, remat_call_p, new_params) for t in unknown_output_tracers: t.recipe = eqn return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns) @@ -792,6 +796,7 @@ def full_lower(self): def _contents(self): return () + # TODO(mattjj): re-enable after #3421 and jaxlib # def __bool__(self): # self._concretization_error('__bool__') @@ -799,7 +804,7 @@ def _contents(self): # self._concretization_error('__int__') # def _concretization_error(self, name): - # msg = self._progenitor_messages() + # msgs = self._progenitor_messages() # msg = (f"Abstract tracer value passed to {name} for which a concrete value " # "is required.\n" # "This tracer originated from using JAX operations on these lines:\n" @@ -925,7 +930,7 @@ def process_primitive(self, primitive, tracers, params): out_tracers = [JaxprTracer2(self, a) for a in out_avals] invars = map(self.getvar, tracers) outvars = map(self.getvar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info()) + eqn = new_jaxpr_eqn(invars, outvars, primitive, params) self.frame.eqns.append(eqn) return out_tracers if primitive.multiple_results else out_tracers.pop() @@ -940,8 +945,7 @@ def 
process_call(self, call_primitive, f, tracers, params): update_params = call_param_updaters.get(call_primitive) if update_params: new_params = update_params(new_params, [True] * len(tracers)) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params, - source_info()) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params) self.frame.eqns.append(eqn) return out_tracers @@ -967,8 +971,7 @@ def process_map(self, map_primitive, f, tracers, params): update_params = call_param_updaters.get(map_primitive) if update_params: new_params = update_params(new_params, [True] * len(tracers)) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params, - source_info()) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params) self.frame.eqns.append(eqn) return out_tracers @@ -1011,13 +1014,19 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): return jaxpr, out_avals, consts -class FrameInfo(NamedTuple): - filename: str - lineno: int +# TODO(mattjj): re-enable after #3421 and jaxlib +# class FrameInfo(NamedTuple): +# filename: str +# lineno: int -def source_info(): - return [FrameInfo(f.filename, f.lineno) for f in inspect.stack()] +# def source_info(): +# try: +# t = xla_client.Traceback.get_traceback() +# except AttributeError: +# return None +# else: +# return [FrameInfo(f.filename, f.lineno) for f in t.frames] -def user_source_info(frame_infos): - base = os.sep.join(__file__.split(os.sep)[:-2]) - return next((f for f in frame_infos if base not in f.filename), None) +# def user_source_info(frame_infos): +# base = os.sep.join(__file__.split(os.sep)[:-2]) +# return next((f for f in frame_infos if base not in f.filename), None)