diff --git a/jax/_src/config.py b/jax/_src/config.py index b99b75945895..1ab7f5f91043 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -883,6 +883,11 @@ def _update_disable_jit_thread_local(val): ) ) +config.define_bool_state( + name='jax_experimental_subjaxpr_lowering_cache', + default=False, + help='Enable using a cache for lowering subjaxprs.') + @contextlib.contextmanager def explicit_device_put_scope() -> Iterator[None]: """Indicates that the current context is an explicit device_put*() call.""" diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 8dd31a3eb0d2..942593570f65 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1818,7 +1818,11 @@ def process_call(self, call_primitive, f, explicit_tracers, params): in_tracers = [*implicit_tracers, *explicit_tracers] # TODO(mattjj): check in_tracers are consistent with f.in_type annotation with core.new_sublevel(): - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main) + if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache: + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main) + else: + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized( + f, self.main).val if jaxpr.effects: raise NotImplementedError('Effects not supported for call primitives.') if params.get('inline', False): @@ -2117,6 +2121,18 @@ def trace_to_subjaxpr_dynamic2( return jaxpr, out_type, consts +@lu.cache +def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun, + main: core.MainTrace): + return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main)) + + +class WrapperForWeakRef: + val: Any + + def __init__(self, val): + self.val = val + @contextlib.contextmanager def extend_jaxpr_stack(main, frame): main.jaxpr_stack = main.jaxpr_stack + (frame,) diff --git a/tests/api_test.py b/tests/api_test.py index 872f9c30dc01..6583da55898e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3779,6 +3779,54 @@ def test_jit_negative_static_argnums(self): g(1, 2) # doesn't crash +@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True) +class SubcallTraceCacheTest(jtu.JaxTestCase): + + def test_subcall_trace_caching(self): + should_be_tracing_f = False + + @api.jit + def f(x): + self.assertTrue(should_be_tracing_f) + return x**2 + + @api.jit + def g(x): + nonlocal should_be_tracing_f + self.assertTrue(should_be_tracing_g) + should_be_tracing_f = True + y = f(x) + should_be_tracing_f = False + z = f(x + 1) + return y + z + + should_be_tracing_g = True + out = g(2) + self.assertEqual(out, 13) + + should_be_tracing_g = False + out = g(3) + self.assertEqual(out, 25) + + def test_subcall_jaxpr_id(self): + + @api.jit + def f(x): + return x**2 + + def g(x): + y = f(x) + z = f(x + 1) + return y + z + + jaxpr = api.make_jaxpr(g)(2) + self.assertIn("call_jaxpr", jaxpr.eqns[0].params) + self.assertIn("call_jaxpr", jaxpr.eqns[2].params) + subjaxpr1 = jaxpr.eqns[0].params["call_jaxpr"] + subjaxpr2 = jaxpr.eqns[2].params["call_jaxpr"] + self.assertIs(subjaxpr1, subjaxpr2) + + class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -4274,7 +4322,7 @@ def g(): return seq[0] remat(g)() - remat(g)() + remat(lambda: g())() # lambda defeats caching with self.assertRaisesRegex(UnexpectedTracerError, "global state"): api.jit(f)()