diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 37ad40d22a3d..c8200fdf9809 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -330,12 +330,34 @@ def write_primal(v, val): # forces primal_in to contain UndefinedPrimals for tangent values! map(write_primal, jaxpr.invars, primals_in) + # Start with a forward pass to evaluate any side-effect-free JaxprEqns that + # only operate on primals. This is required to support primitives with + # linearization rules that include computations on the residuals. + lin_eqns = [] + for eqn in jaxpr.eqns: + # TODO (dfm): The effects check is probably stricter than necessary. + # Consider adding an allowlist of effects here. + if jaxpr.effects or any( + type(x) is not Literal and x not in primal_env for x in eqn.invars): + lin_eqns.append(eqn) + continue + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack), eqn.ctx.manager: + ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) + if eqn.primitive.multiple_results: + map(write_primal, eqn.outvars, ans) + else: + write_primal(eqn.outvars[0], ans) + ct_env: dict[Any, Any] = {} ctx = (source_info_util.transform_name_stack('transpose') if transform_stack else contextlib.nullcontext()) with ctx: map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in jaxpr.eqns[::-1]: + for eqn in lin_eqns[::-1]: if eqn.primitive.ref_primitive: if eqn.primitive is core.mutable_array_p: val_var, = eqn.invars diff --git a/tests/api_test.py b/tests/api_test.py index ff729c03dd71..35e92f748a13 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -57,6 +57,7 @@ from jax._src import debugging from jax._src import pjit as pjit_lib from jax._src.ad_checkpoint import saved_residuals +from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled @@ -4732,6 +4733,19 @@ def sin_of_sin(x): check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + def test_deferred_primal_with_direct_linearize(self): + def my_sin_lin(nzs, x): + nz, = nzs + return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + + my_sin_p = core.Primitive("my_sin_p") + my_sin_p.def_impl(lax.sin) + my_sin_p.def_abstract_eval(lambda x: x) + ad_internal.primitive_linearizations[my_sin_p] = my_sin_lin + + with config.use_direct_linearize(True): + jax.grad(my_sin_p.bind)(1.0) # doesn't crash + class RematTest(jtu.JaxTestCase):