From 05e0716f4004295e720a81a7f46a1aacb14f71a4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 15 Jun 2020 18:42:53 -0700 Subject: [PATCH 1/4] simplify remat partial eval parameterization The main win here is reducing the number of arguments for the function that parameterizes _remat_partial_eval (so it can be used both with remat and invertible ad features). I also included a fix to _remat_partial_eval that is needed in #3370, though I don't think it's needed on master. It was easier to include the fix now. Both these changes made rebasing #3370 easier! --- jax/interpreters/invertible_ad.py | 43 +++-- jax/interpreters/partial_eval.py | 110 +++++------ tests/api_test.py | 299 +++++++++++++++--------------- 3 files changed, 231 insertions(+), 221 deletions(-) diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 990f70a95b8e..8a606810581b 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -22,7 +22,8 @@ from . import ad from . import partial_eval as pe from .partial_eval import (PartialVal, partial_eval_jaxpr, - JaxprTracer, ConstVar, convert_constvars_jaxpr, new_eqn_recipe) + JaxprTracer, ConstVar, convert_constvars_jaxpr, + new_eqn_recipe, _partition_knowns) from ..core import raise_to_shaped, get_aval, Literal, Jaxpr from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs from ..api_util import flatten_fun_nokwargs @@ -43,24 +44,25 @@ invertible_call_p.def_impl(core.call_impl) invertible_call_p.multiple_results = True -def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_known_pvals, out_unknown_pvals, _, params): - unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals] - lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr) - # Add dummy arguments representing the outputs to the jaxpr. Those should remain unused in case - # the expression actually ends up being evaluated, but they make it well-formed. - out_known_avals = tuple(raise_to_shaped(get_aval(pval.get_known())) for pval in out_known_pvals) - lifted_jaxpr = _append_invars(lifted_jaxpr, out_known_avals) - new_params = dict(params, call_jaxpr=lifted_jaxpr) - # We also append some dummy outputs that correspond to the known outputs we left in the call_jaxpr - dummy_outputs = [JaxprTracer(trace, pval, core.unit) for pval in out_known_pvals] - - output_constants = [JaxprTracer(trace, pval, ConstVar(pval.get_known())) for pval in out_known_pvals] - eqn = new_eqn_recipe(tuple(it.chain(in_tracers, output_constants)), - dummy_outputs + unknown_output_tracers, - invertible_call_p, - new_params) - for t in unknown_output_tracers: t.recipe = eqn - return unknown_output_tracers +def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params): + uks = [not t.pval.is_known() for t in out_tracers] + out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks) + + # Add dummy arguments representing the outputs to the jaxpr. Those should + # remain unused if the expression is evaluated, but they make it well-formed. + out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known] + out_consts = [trace.instantiate_const(t) for t in out_tracers_known] + new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals)) + new_in_tracers = (*in_tracers, *out_consts) + + # Append dummy outputs that correspond to known outputs left in the call_jaxpr + dummy_outputs = [JaxprTracer(trace, t.pval, core.unit) for t in out_tracers_known] + new_out_tracers = (*dummy_outputs, *out_tracers_unknown) + + eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p, + dict(params, call_jaxpr=new_jaxpr)) + for t in out_tracers_unknown: t.recipe = eqn + return new_out_tracers pe.call_partial_eval_rules[invertible_call_p] = partial( pe._remat_partial_eval, _invertible_call_make_output_tracers) @@ -69,7 +71,8 @@ def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_kno @cache() def _append_invars(jaxpr, avals): newvar = core.gensym([jaxpr]) - return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), jaxpr.outvars, jaxpr.eqns) + return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), + jaxpr.outvars, jaxpr.eqns) def _invertible_call_transpose(params, call_jaxpr, args, ct, _): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 0673e6a3db79..df664835869e 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -680,7 +680,12 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst remat_call_p.def_impl(core.call_impl) remat_call_p.multiple_results = True -def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params): +# We reuse the _remat_partial_eval function both for remat_call and for +# invertible_call, both of which in a sense stage out operations to +# rematerialize values. The two usages differ only in details of what jaxpr eqn +# and output tracers are formed. As a result we parameterize _remat_partial_eval +# by a `process_out` function. +def _remat_partial_eval(process_out, trace, _, f, tracers, params): concrete = params['concrete'] # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of @@ -699,24 +704,27 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params): # Using the instantiated tracers, run call_bind like JaxprTrace.process_call. in_pvals = [t.pval for t in instantiated_tracers] - with core.initial_style_staging(): - jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval( - f, in_pvals, partial(remat_call_p.bind, **params)) + jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval( + f, in_pvals, partial(remat_call_p.bind, **params)) + + # Convert consts to inputs, since they may contain Tracer instances. + jaxpr = convert_constvars_jaxpr(jaxpr) + const_tracers = map(trace.new_instantiated_const, consts) # Since we traced with everything marked as unknown, but we need to know which # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns. - - in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers] - + [raise_to_shaped(pval.get_aval()) for pval in in_pvals]) + in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] + + [raise_to_shaped(t.pval.get_aval()) for t in env_tracers] + + [raise_to_shaped(pval.get_aval()) for pval in in_pvals]) out_avals = [raise_to_shaped(abstract_unit if var is unitvar else get_aval(var.val) if type(var) is Literal else pval.get_aval()) for var, pval in zip(jaxpr.outvars, eval_out_pvals)] - typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) - in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)] - jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, - instantiate=False, - trace_type=trace.master.trace_type) + typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals) + in_unknowns = ([False] * len(consts) + + [not t.is_known() for t in it.chain(env_tracers, tracers)]) + jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr( + typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type) out_knowns = [not b for b in out_unknowns] out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns) @@ -728,56 +736,52 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params): # values. For the use case of inverse-mode ad in op-by-op ("eager mode") # evaluation, all the primal outputs should be concrete (thus not recomputed). to_compute = [type(pval[0]) is not ConcreteArray - for uk, pval in zip(out_unknowns, eval_out_pvals) - if not uk] + for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk] num_outputs = len(jaxpr_unknown.out_avals) num_res = len(jaxpr_known.out_avals) - num_outputs jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True) jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute) _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers)) - reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts) + reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts) out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts) - # Now that we have out_pvals, the rest is similar to JaxprTrace.process_call. # Known outputs should keep propagating as constants assert all(pv.is_known() for pv in out_known_pvals) - known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals] - - # Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call - const_tracers = map(trace.new_instantiated_const, consts) - unknown_output_tracers = wrap_unknown_pvals( - trace, - typed_jaxpr, - tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)), - out_known_pvals, - out_unknown_pvals, - out_unknowns, - params) - - return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns) - -def _remat_make_output_tracers(trace, typed_jaxpr, input_tracers, _, out_unknown_pvals, out_unknowns, params): - unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals] - typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True) - lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr) - new_params = dict(params, call_jaxpr=lifted_jaxpr) - eqn = new_eqn_recipe(input_tracers, - unknown_output_tracers, - remat_call_p, - new_params) - for t in unknown_output_tracers: t.recipe = eqn - return unknown_output_tracers - -call_partial_eval_rules[remat_call_p] = partial(_remat_partial_eval, _remat_make_output_tracers) - -def _partition_knowns(l, unknowns): - return ([e for e, unknown in zip(l, unknowns) if not unknown], - [e for e, unknown in zip(l, unknowns) if unknown]) - -def _zip_knowns(kl, ul, unknowns): - ul_it = iter(ul) - kl_it = iter(kl) - return [next(ul_it) if unknown else next(kl_it) for unknown in unknowns] + known_output_tracers = [trace.new_const(pval.get_known()) + for pval in out_known_pvals] + # Unknown outputs get wrapped in tracers with the appropriate recipe + unknown_output_tracers = [JaxprTracer(trace, out_pval, None) + for out_pval in out_unknown_pvals] + out_tracers = _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns) + + in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers) + new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(typed_jaxpr.jaxpr)) + return process_out(trace, in_tracers, out_tracers, new_params) + +def _remat_make_output_tracers(_, in_tracers, out_tracers, params): + # dce jaxpr outputs + jaxpr = params['call_jaxpr'] + out_unknowns = [not t.pval.is_known() for t in out_tracers] + typed_jaxpr = core.TypedJaxpr(jaxpr, (), [v.aval for v in jaxpr.invars], + [v.aval for v in jaxpr.outvars]) + new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr + new_params = dict(params, call_jaxpr=new_jaxpr) + + # set up eqn for unknown outputs + unknown_out_tracers = [t for t in out_tracers if not t.pval.is_known()] + eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params) + for t in unknown_out_tracers: t.recipe = eqn + return out_tracers +call_partial_eval_rules[remat_call_p] = partial( + _remat_partial_eval, _remat_make_output_tracers) + +def _partition_knowns(pvals, unknowns: Sequence[bool]): + return ([e for e, unknown in zip(pvals, unknowns) if not unknown], + [e for e, unknown in zip(pvals, unknowns) if unknown]) + +def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]): + known_iter, unknown_iter = iter(known_list), iter(unknown_list) + return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown] def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr: diff --git a/tests/api_test.py b/tests/api_test.py index d890f8926de9..dfbade7b01fc 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1266,6 +1266,156 @@ def test_grad_of_jit_compilation_caching(self): self.assertAllClose(ans1, np.cos(2.), check_dtypes=False) self.assertAllClose(ans2, np.cos(3.), check_dtypes=False) + def test_trivial_computations(self): + x = jnp.array([1, 2, 3]) + y = api.jit(lambda x: x)(x) + self.assertIs(x, y) + + z1, z2 = api.jit(lambda x: (x, x))(x) + self.assertIs(z1, z2) + + x1, x2 = jnp.array([1, 2]), jnp.array([2, 3]) + z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2) + self.assertIs(z1, x2) + self.assertIs(z3, x1) + self.assertEqual(z2, 1) + + def test_nested_jit_hoisting(self): + @api.jit + def f(x, y): + z = 2 * x + return y + z, 3 + + @api.jit + def g(x): + return f(2, x) + + jaxpr_subcomp = xla.jaxpr_subcomp + + jaxprs = [] + def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs): + jaxprs.append(jaxpr) + return jaxpr_subcomp(c, jaxpr, *args, **kwargs) + + try: + xla.jaxpr_subcomp = jaxpr_subcomp_and_collect + ans = g(3) + finally: + xla.jaxpr_subcomp = jaxpr_subcomp + + self.assertEqual(ans, (7, 3)) + self.assertLen(jaxprs, 2) + outer_jaxpr, inner_jaxpr = jaxprs + + self.assertLen(outer_jaxpr.eqns, 1) + self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call') + subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"] + self.assertEqual(str(subjaxpr_1), str(inner_jaxpr)) + self.assertLen(inner_jaxpr.eqns, 2) + self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul') + self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add') + + def test_primitive_compilation_cache(self): + with jtu.count_primitive_compiles() as count: + lax.add(1, 2) + lax.add(2, 3) + self.assertEqual(count[0], 1) + + def test_arange_jit(self): + # see https://github.com/google/jax/issues/553 + def fun(x): + r = jnp.arange(x.shape[0])[x] + return r + + jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash + + def helper_save_tracer(self, x): + self._saved_tracer = x + return x + + def test_escaped_tracers_diffent_top_level_traces(self): + api.jit(self.helper_save_tracer)(0.) + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile( + "Encountered an unexpected tracer.*Different traces at same level", + re.DOTALL)): + api.jit(lambda x: self._saved_tracer)(0.) + + def test_escaped_tracers_cant_lift_sublevels(self): + api.jit(self.helper_save_tracer)(0.) + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile( + "Encountered an unexpected tracer.*Can't lift sublevels 1 to 0", + re.DOTALL)): + api.jit(lambda x: x)(self._saved_tracer) + + def test_escaped_tracers_tracer_from_higher_level(self): + api.grad(self.helper_save_tracer)(0.) + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile( + "Encountered an unexpected tracer.*Tracer from a higher level", + re.DOTALL)): + api.grad(lambda x: x)(self._saved_tracer) + + def test_escaped_tracers_incompatible_sublevel(self): + def func1(x): + api.jit(self.helper_save_tracer)(0.) + # Use the tracer + return x + self._saved_tracer + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile("Encountered an unexpected tracer.*Incompatible sublevel", + re.DOTALL)): + api.jit(func1)(2.) + + def test_escaped_tracers_cant_lift(self): + def func1(x): + api.grad(self.helper_save_tracer)(0.) + return x + self._saved_tracer + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile("Encountered an unexpected tracer.*Can't lift", + re.DOTALL)): + api.grad(func1)(2.) + + def test_escaped_tracers_not_among_input_tracers(self): + def func1(x): + api.grad(self.helper_save_tracer)(x) + # Use the tracer + return x + self._saved_tracer + + with self.assertRaisesRegex( + core.UnexpectedTracerError, + re.compile( + "Encountered an unexpected tracer.*Tracer not among input tracers", + re.DOTALL)): + api.jit(func1)(2.) + + def test_pmap_static_kwarg_error_message(self): + # https://github.com/google/jax/issues/3007 + def f(a, b): + return a + b + + g = jax.pmap(f, static_broadcasted_argnums=(1,)) + + msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was " + r"called with only 1 positional argument. All static broadcasted " + r"arguments must be passed positionally.") + with self.assertRaisesRegex(ValueError, msg): + g(jnp.ones((1, 1)), b=1) + + def test_vmap_unmapped_last(self): + @partial(jax.vmap, out_axes=jax.interpreters.batching.last) + def f(x): + return np.zeros((2,)) + f(np.zeros((5,))) + + +class RematTest(jtu.JaxTestCase): + def test_remat_basic(self): @api.remat def g(x): @@ -1622,154 +1772,6 @@ def call(f, *args): vjp(v) - def test_trivial_computations(self): - x = jnp.array([1, 2, 3]) - y = api.jit(lambda x: x)(x) - self.assertIs(x, y) - - z1, z2 = api.jit(lambda x: (x, x))(x) - self.assertIs(z1, z2) - - x1, x2 = jnp.array([1, 2]), jnp.array([2, 3]) - z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2) - self.assertIs(z1, x2) - self.assertIs(z3, x1) - self.assertEqual(z2, 1) - - def test_nested_jit_hoisting(self): - @api.jit - def f(x, y): - z = 2 * x - return y + z, 3 - - @api.jit - def g(x): - return f(2, x) - - jaxpr_subcomp = xla.jaxpr_subcomp - - jaxprs = [] - def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs): - jaxprs.append(jaxpr) - return jaxpr_subcomp(c, jaxpr, *args, **kwargs) - - try: - xla.jaxpr_subcomp = jaxpr_subcomp_and_collect - ans = g(3) - finally: - xla.jaxpr_subcomp = jaxpr_subcomp - - self.assertEqual(ans, (7, 3)) - self.assertLen(jaxprs, 2) - outer_jaxpr, inner_jaxpr = jaxprs - - self.assertLen(outer_jaxpr.eqns, 1) - self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call') - subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"] - self.assertEqual(str(subjaxpr_1), str(inner_jaxpr)) - self.assertLen(inner_jaxpr.eqns, 2) - self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul') - self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add') - - def test_primitive_compilation_cache(self): - with jtu.count_primitive_compiles() as count: - lax.add(1, 2) - lax.add(2, 3) - self.assertEqual(count[0], 1) - - def test_arange_jit(self): - # see https://github.com/google/jax/issues/553 - def fun(x): - r = jnp.arange(x.shape[0])[x] - return r - - jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash - - def helper_save_tracer(self, x): - self._saved_tracer = x - return x - - def test_escaped_tracers_diffent_top_level_traces(self): - api.jit(self.helper_save_tracer)(0.) - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Different traces at same level", - re.DOTALL)): - api.jit(lambda x: self._saved_tracer)(0.) - - def test_escaped_tracers_cant_lift_sublevels(self): - api.jit(self.helper_save_tracer)(0.) - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Can't lift sublevels 1 to 0", - re.DOTALL)): - api.jit(lambda x: x)(self._saved_tracer) - - def test_escaped_tracers_tracer_from_higher_level(self): - api.grad(self.helper_save_tracer)(0.) - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Tracer from a higher level", - re.DOTALL)): - api.grad(lambda x: x)(self._saved_tracer) - - def test_escaped_tracers_incompatible_sublevel(self): - def func1(x): - api.jit(self.helper_save_tracer)(0.) - # Use the tracer - return x + self._saved_tracer - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile("Encountered an unexpected tracer.*Incompatible sublevel", - re.DOTALL)): - api.jit(func1)(2.) - - def test_escaped_tracers_cant_lift(self): - def func1(x): - api.grad(self.helper_save_tracer)(0.) - return x + self._saved_tracer - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile("Encountered an unexpected tracer.*Can't lift", - re.DOTALL)): - api.grad(func1)(2.) - - def test_escaped_tracers_not_among_input_tracers(self): - def func1(x): - api.grad(self.helper_save_tracer)(x) - # Use the tracer - return x + self._saved_tracer - - with self.assertRaisesRegex( - core.UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Tracer not among input tracers", - re.DOTALL)): - api.jit(func1)(2.) - - def test_pmap_static_kwarg_error_message(self): - # https://github.com/google/jax/issues/3007 - def f(a, b): - return a + b - - g = jax.pmap(f, static_broadcasted_argnums=(1,)) - - msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was " - r"called with only 1 positional argument. All static broadcasted " - r"arguments must be passed positionally.") - with self.assertRaisesRegex(ValueError, msg): - g(jnp.ones((1, 1)), b=1) - - def test_vmap_unmapped_last(self): - @partial(jax.vmap, out_axes=jax.interpreters.batching.last) - def f(x): - return np.zeros((2,)) - f(np.zeros((5,))) - - class JaxprTest(jtu.JaxTestCase): def test_scalar_literals(self): @@ -2833,6 +2835,7 @@ def clip_gradient(x): jax.grad(clip_gradient)(1.) # doesn't crash + class InvertibleADTest(jtu.JaxTestCase): def test_invertible_basic(self): From 8a901ba064bdf7a20edc765fc0ed5d2aec6ac489 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 15 Jun 2020 19:36:45 -0700 Subject: [PATCH 2/4] deflake --- jax/interpreters/invertible_ad.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 8a606810581b..93f83bbb4491 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import partial -import itertools as it from typing import Dict, Any, Callable import jax @@ -21,8 +20,7 @@ from jax import linear_util as lu from . import ad from . import partial_eval as pe -from .partial_eval import (PartialVal, partial_eval_jaxpr, - JaxprTracer, ConstVar, convert_constvars_jaxpr, +from .partial_eval import (PartialVal, partial_eval_jaxpr, JaxprTracer, new_eqn_recipe, _partition_knowns) from ..core import raise_to_shaped, get_aval, Literal, Jaxpr from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs From 005958e13ec0d30cd45192673e761e886622db6e Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 16 Jun 2020 11:46:37 -0700 Subject: [PATCH 3/4] added reviewer suggestion --- jax/interpreters/invertible_ad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 93f83bbb4491..1a8f2e8225c3 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -54,7 +54,7 @@ def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params) new_in_tracers = (*in_tracers, *out_consts) # Append dummy outputs that correspond to known outputs left in the call_jaxpr - dummy_outputs = [JaxprTracer(trace, t.pval, core.unit) for t in out_tracers_known] + dummy_outputs = [trace.new_const(t.pval.get_known()) for t in out_tracers_known] new_out_tracers = (*dummy_outputs, *out_tracers_unknown) eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p, From dfdf05fe50ac733f25e5727df0aa13798b713a4a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 16 Jun 2020 11:56:42 -0700 Subject: [PATCH 4/4] deflake --- jax/interpreters/invertible_ad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/interpreters/invertible_ad.py b/jax/interpreters/invertible_ad.py index 1a8f2e8225c3..e8d8293d8b82 100644 --- a/jax/interpreters/invertible_ad.py +++ b/jax/interpreters/invertible_ad.py @@ -20,8 +20,8 @@ from jax import linear_util as lu from . import ad from . import partial_eval as pe -from .partial_eval import (PartialVal, partial_eval_jaxpr, JaxprTracer, - new_eqn_recipe, _partition_knowns) +from .partial_eval import (PartialVal, partial_eval_jaxpr, new_eqn_recipe, + _partition_knowns) from ..core import raise_to_shaped, get_aval, Literal, Jaxpr from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs from ..api_util import flatten_fun_nokwargs