Skip to content

Commit ab0ce8a

Browse files
Merge pull request #26811 from dfm:direct-lin
PiperOrigin-RevId: 735388827
2 parents d2bf034 + b7ecfdf commit ab0ce8a

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

jax/_src/interpreters/ad.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,34 @@ def write_primal(v, val):
349349
# forces primal_in to contain UndefinedPrimals for tangent values!
350350
map(write_primal, jaxpr.invars, primals_in)
351351

352+
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
353+
# only operate on primals. This is required to support primitives with
354+
# linearization rules that include computations on the residuals.
355+
lin_eqns = []
356+
for eqn in jaxpr.eqns:
357+
# TODO (dfm): The effects check is probably stricter than necessary.
358+
# Consider adding an allowlist of effects here.
359+
if jaxpr.effects or any(
360+
type(x) is not Literal and x not in primal_env for x in eqn.invars):
361+
lin_eqns.append(eqn)
362+
continue
363+
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
364+
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
365+
traceback = eqn.source_info.traceback
366+
with source_info_util.user_context(
367+
traceback, name_stack=name_stack), eqn.ctx.manager:
368+
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
369+
if eqn.primitive.multiple_results:
370+
map(write_primal, eqn.outvars, ans)
371+
else:
372+
write_primal(eqn.outvars[0], ans)
373+
352374
ct_env: dict[Any, Any] = {}
353375
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
354376
else contextlib.nullcontext())
355377
with ctx:
356378
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
357-
for eqn in jaxpr.eqns[::-1]:
379+
for eqn in lin_eqns[::-1]:
358380
if eqn.primitive.ref_primitive:
359381
if eqn.primitive is core.mutable_array_p:
360382
val_var, = eqn.invars

tests/api_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from jax._src import debugging
5858
from jax._src import pjit as pjit_lib
5959
from jax._src.ad_checkpoint import saved_residuals
60+
from jax._src.interpreters import ad as ad_internal
6061
from jax._src.interpreters import mlir
6162
from jax._src.interpreters import partial_eval as pe
6263
from jax._src.compilation_cache import is_persistent_cache_enabled
@@ -4732,6 +4733,19 @@ def sin_of_sin(x):
47324733

47334734
check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))
47344735

4736+
def test_deferred_primal_with_direct_linearize(self):
4737+
def my_sin_lin(nzs, x):
4738+
nz, = nzs
4739+
return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x)))
4740+
4741+
my_sin_p = core.Primitive("my_sin_p")
4742+
my_sin_p.def_impl(lax.sin)
4743+
my_sin_p.def_abstract_eval(lambda x: x)
4744+
ad_internal.primitive_linearizations[my_sin_p] = my_sin_lin
4745+
4746+
with config.use_direct_linearize(True):
4747+
jax.grad(my_sin_p.bind)(1.0) # doesn't crash
4748+
47354749

47364750
class RematTest(jtu.JaxTestCase):
47374751

0 commit comments

Comments
 (0)