Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache tracing of (sub)calls when forming a jaxpr #9181

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Optional
from typing import Any, List, Callable, NamedTuple, Optional, Hashable
import warnings

from jax._src import lib
Expand Down Expand Up @@ -324,15 +324,8 @@ def validate(new_val):

return _StateContextManager(name, help, update_thread_local_hook, validate)

def _trace_context(self):
"""Returns a tuple of configuration values that affect tracing.

These values are included in the cache key for linear_util.cache.

Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision)
def get_thread_local_trace_state(self):
return get_thread_local_trace_state()

class _StateContextManager:
def __init__(self, name, help, update_thread_local_hook,
Expand Down Expand Up @@ -405,7 +398,7 @@ def __setattr__(self, name, val):

class GlobalJitState(NamedTuple):
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
default_matmul_precision: Optional[Hashable] = None


def update_global_jit_state(**kw):
Expand All @@ -415,9 +408,10 @@ def update_global_jit_state(**kw):


class ThreadLocalJitState(NamedTuple):
dynamic_trace_state: Optional[Any] = None
dynamic_trace_state: Optional[Hashable] = None
axis_env_state: Optional[Hashable] = None
numpy_rank_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
default_matmul_precision: Optional[Hashable] = None


def update_thread_local_jit_state(**kw):
Expand All @@ -426,6 +420,13 @@ def update_thread_local_jit_state(**kw):
tls.extra_jit_context = context._replace(**kw)


def get_thread_local_trace_state() -> Hashable:
tls = jax_jit.thread_local_state()
ctx = tls.extra_jit_context or ThreadLocalJitState()
return (ctx.axis_env_state, config.jax_enable_x64, config.jax_disable_jit,
config.jax_numpy_rank_promotion, config.jax_default_matmul_precision)


# TODO(mattjj): remove all uses of this flag
flags.DEFINE_bool(
'jax_omnistaging',
Expand Down
3 changes: 0 additions & 3 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def _initial_style_jaxpr(fun, in_avals):
def _close_jaxpr(jaxpr):
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

def _initial_style_staging() -> bool:
return core.thread_local_state.trace_state.initial_style

def _sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)

Expand Down
1 change: 1 addition & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6769,6 +6769,7 @@ def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)


@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
start_indices: Tuple[Tuple[int, ...]],
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def wrapper(*args, **kwargs):
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(config._trace_context(), *args, **kwargs)
return cached(config.get_thread_local_trace_state(), *args, **kwargs)

wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
Expand Down
20 changes: 20 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,20 +1854,40 @@ def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> S
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
frame = AxisEnvFrame(axis_name, size, tag)
thread_local_state.trace_state.axis_env.append(frame)
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(thread_local_state.trace_state.axis_env))
try:
yield
finally:
thread_local_state.trace_state.axis_env.pop()
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(thread_local_state.trace_state.axis_env))

@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
thread_local_state.trace_state.axis_env.extend(frames)
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(thread_local_state.trace_state.axis_env))
try:
yield
finally:
for _ in frames:
thread_local_state.trace_state.axis_env.pop()
jax_config.update_thread_local_jit_state(
axis_env_state=tuple(thread_local_state.trace_state.axis_env))

@contextmanager
def stash_axis_env():
"Promise that a function or with-suite does not depend implicitly on axis env"
s = thread_local_state.trace_state
prev_axis_env, s.axis_env = s.axis_env, []
jax_config.update_thread_local_jit_state(axis_env_state=())
try:
yield
finally:
s.axis_env = prev_axis_env
jax_config.update_thread_local_jit_state(axis_env_state=tuple(s.axis_env))


# When a mapped function is given no axis name, we generate a name object based
Expand Down
1 change: 1 addition & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def out_axes_thunk():

# The freevars are being fanned out (not mapped). During transpose the
# dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
# TODO(mattjj,jekbradbury): should this look at global_axis_size?
assert len(in_axes) == len(arg_cts)
def unmap_zero(zero, in_axis):
return (zero if in_axis is None else
Expand Down
20 changes: 18 additions & 2 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,13 @@ def process_call(self, call_primitive, f, tracers, params):
in_avals = _tracers_to_avals({}, dim_tracers + tracers)
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
if config.jax_check_tracer_leaks:
# Don't want to keep a strong ref to 'main' in memoization cache key
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
else:
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic_memoized(
f, self.main, tuple(in_avals), tuple(keep_inputs)).val
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
source_info = source_info_util.current()
Expand Down Expand Up @@ -1603,6 +1608,17 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
del fun, main, trace, frame, in_tracers, out_tracers, ans
return jaxpr, out_avals, consts

@lu.cache
def trace_to_subjaxpr_dynamic_memoized(
fun: lu.WrappedFun, main: core.MainTrace, in_avals, keep_inputs):
tup = trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs=keep_inputs)
return WrapperForWeakRef(tup)

class WrapperForWeakRef:
val: Any
def __init__(self, val):
self.val = val

@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)
Expand Down
6 changes: 3 additions & 3 deletions jax/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,10 @@ def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
config.get_thread_local_trace_state())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
key = (fun.transforms, fun.params, args,
config.get_thread_local_trace_state())
result = cache.get(key, None)
if result is not None:
ans, stores = result
Expand Down
59 changes: 57 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,19 @@ def f(d) -> float:
with self.assertRaisesRegex(TypeError, "'<' not supported.*"):
f({E.A: 1.0, E.B: 2.0})

def test_caches_depend_on_axis_env(self):
# https://github.com/google/jax/issues/9187
f = lambda: lax.psum(1, 'i')
g = jax.jit(f)
expected = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
ans = jax.vmap(g, axis_name='i', axis_size=2, out_axes=None)()
self.assertEqual(ans, expected)

# This second call to g could erroneously get a cache hit.
expected = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
ans = jax.vmap(g, axis_name='i', axis_size=3, out_axes=None)()
self.assertEqual(ans, expected)


class PythonJitTest(CPPJitTest):

Expand Down Expand Up @@ -1247,7 +1260,6 @@ def f(x, u):
self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))


def test_large_device_constant(self):
ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash
self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False)
Expand Down Expand Up @@ -3309,6 +3321,49 @@ def test_jnp_array_doesnt_device_put(self):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)

def test_subcall_trace_caching(self):
should_be_tracing_f = False

@api.jit
def f(x):
self.assertTrue(should_be_tracing_f)
return x ** 2

@api.jit
def g(x):
nonlocal should_be_tracing_f
self.assertTrue(should_be_tracing_g)
should_be_tracing_f = True
y = f(x)
should_be_tracing_f = False
z = f(x + 1)
return y + z

should_be_tracing_g = True
out = g(2)
self.assertEqual(out, 13)

should_be_tracing_g = False
out = g(3)
self.assertEqual(out, 25)

def test_subcall_jaxpr_id(self):
@api.jit
def f(x):
return x ** 2

def g(x):
y = f(x)
z = f(x + 1)
return y + z

jaxpr = api.make_jaxpr(g)(2)
self.assertIn('call_jaxpr', jaxpr.eqns[0].params)
self.assertIn('call_jaxpr', jaxpr.eqns[2].params)
subjaxpr1 = jaxpr.eqns[0].params['call_jaxpr']
subjaxpr2 = jaxpr.eqns[2].params['call_jaxpr']
self.assertIs(subjaxpr1, subjaxpr2)


class RematTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -3763,7 +3818,7 @@ def g():
return seq[0]

remat(g)()
remat(g)()
remat(lambda: g())() # lambda defeats caching

with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
api.jit(f)()
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2908,5 +2908,15 @@ def body(carry):
return lax.while_loop(cond, body, (i, jnp.ones(3)))[1]
jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4)))

def test_caches_depend_on_axis_env(self):
# https://github.com/google/jax/issues/9187
scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
self.assertEqual(ans, 2)
ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
self.assertEqual(ans, 3)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
5 changes: 3 additions & 2 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1877,11 +1877,12 @@ def f(x):
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
self.assertEqual(count[0], 2) # once for fwd, once for bwd

with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_ = jax.vjp(f, x)
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd

@unittest.skipIf(jax._src.lib._xla_extension_version < 44,
Expand Down