From 240079fb14c11273b07c8eb18ca9c83c8bc241b2 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 12 Jan 2022 21:08:39 -0800 Subject: [PATCH 1/3] add axis env state to cache keys, fixes #9187 --- jax/_src/config.py | 27 ++++++++++++++------------- jax/_src/custom_derivatives.py | 3 --- jax/_src/numpy/lax_numpy.py | 1 + jax/_src/util.py | 2 +- jax/core.py | 20 ++++++++++++++++++++ jax/interpreters/ad.py | 1 + jax/linear_util.py | 6 +++--- tests/api_test.py | 13 +++++++++++++ tests/lax_control_flow_test.py | 10 ++++++++++ tests/pmap_test.py | 3 ++- 10 files changed, 65 insertions(+), 21 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index c2d174283e47..b0a736394b33 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -21,7 +21,7 @@ import os import sys import threading -from typing import Any, List, Callable, NamedTuple, Optional +from typing import Any, List, Callable, NamedTuple, Optional, Hashable import warnings from jax._src import lib @@ -324,15 +324,8 @@ def validate(new_val): return _StateContextManager(name, help, update_thread_local_hook, validate) - def _trace_context(self): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately.""" - return (self.x64_enabled, self.jax_numpy_rank_promotion, - self.jax_default_matmul_precision) + def get_thread_local_trace_state(self): + return get_thread_local_trace_state() class _StateContextManager: def __init__(self, name, help, update_thread_local_hook, @@ -405,7 +398,7 @@ def __setattr__(self, name, val): class GlobalJitState(NamedTuple): numpy_rank_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + default_matmul_precision: Optional[Hashable] = None def update_global_jit_state(**kw): @@ -415,9 +408,10 @@ def update_global_jit_state(**kw): class ThreadLocalJitState(NamedTuple): - dynamic_trace_state: Optional[Any] = None + dynamic_trace_state: Optional[Hashable] = None + axis_env_state: Optional[Hashable] = None numpy_rank_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + default_matmul_precision: Optional[Hashable] = None def update_thread_local_jit_state(**kw): @@ -426,6 +420,13 @@ def update_thread_local_jit_state(**kw): tls.extra_jit_context = context._replace(**kw) +def get_thread_local_trace_state() -> Hashable: + tls = jax_jit.thread_local_state() + ctx = tls.extra_jit_context or ThreadLocalJitState() + return (ctx.axis_env_state, config.jax_enable_x64, config.jax_disable_jit, + config.jax_numpy_rank_promotion, config.jax_default_matmul_precision) + + # TODO(mattjj): remove all uses of this flag flags.DEFINE_bool( 'jax_omnistaging', diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e60110e66e8a..131859795e78 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -62,9 +62,6 @@ def _initial_style_jaxpr(fun, in_avals): def _close_jaxpr(jaxpr): return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) -def _initial_style_staging() -> bool: - return core.thread_local_state.trace_state.initial_style - def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b86cf72cb84..843a1c7f2cd1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6769,6 +6769,7 @@ def _compress_method(a, condition, axis=None, out=None): return compress(condition, a, axis, out) +@core.stash_axis_env() @partial(jit, static_argnums=(1,2,3)) def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], diff --git a/jax/_src/util.py b/jax/_src/util.py index a762253c27b6..ad35c1467a3b 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -199,7 +199,7 @@ def wrapper(*args, **kwargs): if config.jax_check_tracer_leaks: return f(*args, **kwargs) else: - return cached(config._trace_context(), *args, **kwargs) + return cached(config.get_thread_local_trace_state(), *args, **kwargs) wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info diff --git a/jax/core.py b/jax/core.py index db7fd96d7393..8d717930b3a4 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1854,20 +1854,40 @@ def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> S def extend_axis_env(axis_name: AxisName, size: int, tag: Any): frame = AxisEnvFrame(axis_name, size, tag) thread_local_state.trace_state.axis_env.append(frame) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) try: yield finally: thread_local_state.trace_state.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) @contextmanager def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]): frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes] thread_local_state.trace_state.axis_env.extend(frames) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) try: yield finally: for _ in frames: thread_local_state.trace_state.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) + +@contextmanager +def stash_axis_env(): + "Promise that a function or with-suite does not depend implicitly on axis env" + s = thread_local_state.trace_state + prev_axis_env, s.axis_env = s.axis_env, [] + jax_config.update_thread_local_jit_state(axis_env_state=()) + try: + yield + finally: + s.axis_env = prev_axis_env + jax_config.update_thread_local_jit_state(axis_env_state=tuple(s.axis_env)) # When a mapped function is given no axis name, we generate a name object based diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 1a6398aa4443..fbfdf2189c1a 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -625,6 +625,7 @@ def out_axes_thunk(): # The freevars are being fanned out (not mapped). During transpose the # dual of fan-out is fan-in-sum. We apply it to the unmapped invars. + # TODO(mattjj,jekbradbury): should this look at global_axis_size? assert len(in_axes) == len(arg_cts) def unmap_zero(zero, in_axis): return (zero if in_axis is None else diff --git a/jax/linear_util.py b/jax/linear_util.py index 397790aec701..df34b675321b 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -260,10 +260,10 @@ def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, {}) if config.jax_check_tracer_leaks: key = (_copy_main_traces(fun.transforms), fun.params, args, - config.x64_enabled, config._trace_context()) + config.get_thread_local_trace_state()) else: - key = (fun.transforms, fun.params, args, config.x64_enabled, - config._trace_context()) + key = (fun.transforms, fun.params, args, + config.get_thread_local_trace_state()) result = cache.get(key, None) if result is not None: ans, stores = result diff --git a/tests/api_test.py b/tests/api_test.py index 692931c25bfc..59fa0ceabd10 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -828,6 +828,19 @@ def f(d) -> float: with self.assertRaisesRegex(TypeError, "'<' not supported.*"): f({E.A: 1.0, E.B: 2.0}) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + f = lambda: lax.psum(1, 'i') + g = jax.jit(f) + expected = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() + ans = jax.vmap(g, axis_name='i', axis_size=2, out_axes=None)() + self.assertEqual(ans, expected) + + # This second call to g could erroneously get a cache hit. + expected = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)() + ans = jax.vmap(g, axis_name='i', axis_size=3, out_axes=None)() + self.assertEqual(ans, expected) + class PythonJitTest(CPPJitTest): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d32bce80f4b6..26561eb23ee2 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2908,5 +2908,15 @@ def body(carry): return lax.while_loop(cond, body, (i, jnp.ones(3)))[1] jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4))) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] + ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() + self.assertEqual(ans, 2) + ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)() + self.assertEqual(ans, 3) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 5e8db1e5d304..37a7c8ee5b49 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1880,8 +1880,9 @@ def f(x): self.assertEqual(count[0], 2) # one for fwd, one for bwd with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - _ = jax.vjp(f, x) + _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) + _ = f_bwd2(x) self.assertEqual(count[0], 0) # cache hits on fwd and bwd @unittest.skipIf(jax._src.lib._xla_extension_version < 44, From 2d4e797f69bf7d65dcf4ff30bbf0697f296a656e Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 12 Jan 2022 21:08:39 -0800 Subject: [PATCH 2/3] add axis env state to cache keys, fixes #9187 --- jax/_src/numpy/lax_numpy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 843a1c7f2cd1..644f39c4443d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6775,6 +6775,8 @@ def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], limit_indices: Tuple[Tuple[int, ...]], removed_dims: Tuple[Tuple[int, ...]]): + print(core.thread_local_state.trace_state.axis_env) + breakpoint() """Extracts multiple slices from `arr`. This is used to shard DeviceArray arguments to pmap. It's implemented as a From 8b5a9f5916121db6cb40ce021f63012b818be07c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 12 Jan 2022 13:14:02 -0800 Subject: [PATCH 3/3] cache tracing of (sub)calls when forming a jaxpr --- jax/_src/numpy/lax_numpy.py | 2 -- jax/interpreters/partial_eval.py | 20 ++++++++++++-- tests/api_test.py | 46 ++++++++++++++++++++++++++++++-- tests/pmap_test.py | 2 +- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 644f39c4443d..843a1c7f2cd1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6775,8 +6775,6 @@ def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], limit_indices: Tuple[Tuple[int, ...]], removed_dims: Tuple[Tuple[int, ...]]): - print(core.thread_local_state.trace_state.axis_env) - breakpoint() """Extracts multiple slices from `arr`. This is used to shard DeviceArray arguments to pmap. It's implemented as a diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 3e35b568852a..18a52cf32e7e 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1361,8 +1361,13 @@ def process_call(self, call_primitive, f, tracers, params): in_avals = _tracers_to_avals({}, dim_tracers + tracers) keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers) with core.new_sublevel(): - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - f, self.main, in_avals, keep_inputs=keep_inputs) + if config.jax_check_tracer_leaks: + # Don't want to keep a strong ref to 'main' in memoization cache key + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( + f, self.main, in_avals, keep_inputs=keep_inputs) + else: + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic_memoized( + f, self.main, tuple(in_avals), tuple(keep_inputs)).val if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers) source_info = source_info_util.current() @@ -1603,6 +1608,17 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, del fun, main, trace, frame, in_tracers, out_tracers, ans return jaxpr, out_avals, consts +@lu.cache +def trace_to_subjaxpr_dynamic_memoized( + fun: lu.WrappedFun, main: core.MainTrace, in_avals, keep_inputs): + tup = trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs=keep_inputs) + return WrapperForWeakRef(tup) + +class WrapperForWeakRef: + val: Any + def __init__(self, val): + self.val = val + @contextlib.contextmanager def extend_jaxpr_stack(main, frame): main.jaxpr_stack = main.jaxpr_stack + (frame,) diff --git a/tests/api_test.py b/tests/api_test.py index 59fa0ceabd10..9cdec48e0941 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1260,7 +1260,6 @@ def f(x, u): self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2)) self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2)) - def test_large_device_constant(self): ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False) @@ -3322,6 +3321,49 @@ def test_jnp_array_doesnt_device_put(self): api.make_jaxpr(lambda: jnp.array(3))() self.assertEqual(count[0], 0) + def test_subcall_trace_caching(self): + should_be_tracing_f = False + + @api.jit + def f(x): + self.assertTrue(should_be_tracing_f) + return x ** 2 + + @api.jit + def g(x): + nonlocal should_be_tracing_f + self.assertTrue(should_be_tracing_g) + should_be_tracing_f = True + y = f(x) + should_be_tracing_f = False + z = f(x + 1) + return y + z + + should_be_tracing_g = True + out = g(2) + self.assertEqual(out, 13) + + should_be_tracing_g = False + out = g(3) + self.assertEqual(out, 25) + + def test_subcall_jaxpr_id(self): + @api.jit + def f(x): + return x ** 2 + + def g(x): + y = f(x) + z = f(x + 1) + return y + z + + jaxpr = api.make_jaxpr(g)(2) + self.assertIn('call_jaxpr', jaxpr.eqns[0].params) + self.assertIn('call_jaxpr', jaxpr.eqns[2].params) + subjaxpr1 = jaxpr.eqns[0].params['call_jaxpr'] + subjaxpr2 = jaxpr.eqns[2].params['call_jaxpr'] + self.assertIs(subjaxpr1, subjaxpr2) + class RematTest(jtu.JaxTestCase): @@ -3776,7 +3818,7 @@ def g(): return seq[0] remat(g)() - remat(g)() + remat(lambda: g())() # lambda defeats caching with self.assertRaisesRegex(UnexpectedTracerError, "global state"): api.jit(f)() diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 37a7c8ee5b49..d035099fcdb3 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1877,7 +1877,7 @@ def f(x): with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd = jax.vjp(f, x) _ = f_bwd(x) - self.assertEqual(count[0], 2) # one for fwd, one for bwd + self.assertEqual(count[0], 2) # once for fwd, once for bwd with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x)