Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching to trace_to_subjaxpr_dynamic. #10775

Merged
merged 1 commit into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,11 @@ def _update_disable_jit_thread_local(val):
)
)

config.define_bool_state(
name='jax_experimental_subjaxpr_lowering_cache',
default=False,
help='Enable using a cache for lowering subjaxprs.')

@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
Expand Down
18 changes: 17 additions & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,11 @@ def process_call(self, call_primitive, f, explicit_tracers, params):
in_tracers = [*implicit_tracers, *explicit_tracers]
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
with core.new_sublevel():
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
if config.jax_check_tracer_leaks or not config.jax_experimental_subjaxpr_lowering_cache:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
else:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized(
f, self.main).val
if jaxpr.effects:
raise NotImplementedError('Effects not supported for call primitives.')
if params.get('inline', False):
Expand Down Expand Up @@ -2117,6 +2121,18 @@ def trace_to_subjaxpr_dynamic2(
return jaxpr, out_type, consts


@lu.cache
def trace_to_subjaxpr_dynamic2_memoized(fun: lu.WrappedFun,
main: core.MainTrace):
return WrapperForWeakRef(trace_to_subjaxpr_dynamic2(fun, main))


class WrapperForWeakRef:
val: Any

def __init__(self, val):
self.val = val

@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)
Expand Down
50 changes: 49 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3779,6 +3779,54 @@ def test_jit_negative_static_argnums(self):
g(1, 2) # doesn't crash


@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
class SubcallTraceCacheTest(jtu.JaxTestCase):

def test_subcall_trace_caching(self):
should_be_tracing_f = False

@api.jit
def f(x):
self.assertTrue(should_be_tracing_f)
return x**2

@api.jit
def g(x):
nonlocal should_be_tracing_f
self.assertTrue(should_be_tracing_g)
should_be_tracing_f = True
y = f(x)
should_be_tracing_f = False
z = f(x + 1)
return y + z

should_be_tracing_g = True
out = g(2)
self.assertEqual(out, 13)

should_be_tracing_g = False
out = g(3)
self.assertEqual(out, 25)

def test_subcall_jaxpr_id(self):

@api.jit
def f(x):
return x**2

def g(x):
y = f(x)
z = f(x + 1)
return y + z

jaxpr = api.make_jaxpr(g)(2)
self.assertIn("call_jaxpr", jaxpr.eqns[0].params)
self.assertIn("call_jaxpr", jaxpr.eqns[2].params)
subjaxpr1 = jaxpr.eqns[0].params["call_jaxpr"]
subjaxpr2 = jaxpr.eqns[2].params["call_jaxpr"]
self.assertIs(subjaxpr1, subjaxpr2)


class RematTest(jtu.JaxTestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -4274,7 +4322,7 @@ def g():
return seq[0]

remat(g)()
remat(g)()
remat(lambda: g())() # lambda defeats caching

with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
api.jit(f)()
Expand Down