@@ -349,12 +349,34 @@ def write_primal(v, val):
349
349
# forces primal_in to contain UndefinedPrimals for tangent values!
350
350
map (write_primal , jaxpr .invars , primals_in )
351
351
352
+ # Start with a forward pass to evaluate any side-effect-free JaxprEqns that
353
+ # only operate on primals. This is required to support primitives with
354
+ # linearization rules that include computations on the residuals.
355
+ lin_eqns = []
356
+ for eqn in jaxpr .eqns :
357
+ # TODO (dfm): The effects check is probably stricter than necessary.
358
+ # Consider adding an allowlist of effects here.
359
+ if jaxpr .effects or any (
360
+ type (x ) is not Literal and x not in primal_env for x in eqn .invars ):
361
+ lin_eqns .append (eqn )
362
+ continue
363
+ subfuns , bind_params = eqn .primitive .get_bind_params (eqn .params )
364
+ name_stack = source_info_util .current_name_stack () + eqn .source_info .name_stack
365
+ traceback = eqn .source_info .traceback
366
+ with source_info_util .user_context (
367
+ traceback , name_stack = name_stack ), eqn .ctx .manager :
368
+ ans = eqn .primitive .bind (* subfuns , * map (read_primal , eqn .invars ), ** bind_params )
369
+ if eqn .primitive .multiple_results :
370
+ map (write_primal , eqn .outvars , ans )
371
+ else :
372
+ write_primal (eqn .outvars [0 ], ans )
373
+
352
374
ct_env : dict [Any , Any ] = {}
353
375
ctx = (source_info_util .transform_name_stack ('transpose' ) if transform_stack
354
376
else contextlib .nullcontext ())
355
377
with ctx :
356
378
map (partial (write_cotangent , 'outvars' ), jaxpr .outvars , cotangents_in )
357
- for eqn in jaxpr . eqns [::- 1 ]:
379
+ for eqn in lin_eqns [::- 1 ]:
358
380
if eqn .primitive .ref_primitive :
359
381
if eqn .primitive is core .mutable_array_p :
360
382
val_var , = eqn .invars
0 commit comments