I am trying to get both forward and reverse modes to work with a loop. Until now, I was using `equinox.internal.scan` (`eqxi.scan`).
From the internal documentation,
I see that when `scan` is checkpointed, only reverse mode can be used, whereas in `lax` mode it is the same as `jax.lax.scan`?
The reason is that `eqxi.scan(..., kind="checkpointed")` uses `jax.custom_vjp` under the hood. Unfortunately, JAX currently supports only reverse-mode autodifferentiation through `jax.custom_vjp`.
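A minimal sketch of that limitation, using a hypothetical toy function `f` with a hand-written VJP (this is not `eqxi.scan` itself, just the same mechanism): `jax.grad` works, while `jax.jvp` raises a `TypeError` because a `custom_vjp` function provides no forward-mode rule.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function with a custom VJP (illustrative only).
@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Primal output plus residuals saved for the backward pass.
    return jnp.sin(x), jnp.cos(x)

def f_bwd(cos_x, g):
    # Cotangent rule: d sin(x)/dx = cos(x).
    return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

x0 = jnp.float32(1.0)

# Reverse mode works: JAX uses f_fwd / f_bwd.
grad_val = jax.grad(f)(x0)

# Forward mode does not: custom_vjp gives JAX no JVP rule to use.
try:
    jax.jvp(f, (x0,), (jnp.float32(1.0),))
    jvp_failed = False
except TypeError as err:
    jvp_failed = True
    print(err)
```

If forward mode is needed for a custom rule, `jax.custom_jvp` is the variant that supports both modes (reverse mode is then derived by transposing the JVP).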
(FWIW, the benefit of `eqxi.scan` over `jax.lax.scan` is that the former uses binomial checkpointing, so it requires only logarithmic memory (in the number of steps) to backpropagate through.)
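If both modes are required, one possible workaround (a sketch, not what `eqxi.scan` does) is to keep `jax.lax.scan` and wrap the step function in `jax.checkpoint`. Unlike `jax.custom_vjp`, `jax.checkpoint` has a forward-mode rule. Note this is plain per-step rematerialisation, not binomial checkpointing: the carry is still stored at every step, so memory stays linear in the number of steps. The recurrence `step`/`run` below is hypothetical.

```python
import jax
import jax.numpy as jnp

# One step of a hypothetical toy recurrence.
# jax.checkpoint drops this step's intermediates on the forward pass and
# recomputes them during backpropagation; it also supports jvp.
@jax.checkpoint
def step(carry, x):
    new = jnp.tanh(carry * x + 0.5)
    return new, None  # no per-step outputs needed

def run(theta, xs):
    final, _ = jax.lax.scan(step, theta, xs)
    return final

xs = jnp.linspace(0.1, 1.0, 10)
theta = jnp.float32(0.3)

# Reverse mode: d(final)/d(theta).
rev = jax.grad(run)(theta, xs)

# Forward mode: push a unit tangent on theta through the loop.
_, fwd = jax.jvp(lambda t: run(t, xs), (theta,), (jnp.float32(1.0),))
```

Both calls compute the same scalar derivative, just in opposite directions through the loop.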
Hi Patrick,

From JAX's docs (https://jax.readthedocs.io/en/latest/control-flow.html#summary), `lax.scan` is differentiable in both modes (and jittable as well). Is there a reason for not supporting both modes here?
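For reference, a small sketch of that claim: a running product written with `jax.lax.scan` (the function `cumulative_product` is a hypothetical example), differentiated with both `jax.jacfwd` (forward mode) and `jax.jacrev` (reverse mode) under `jax.jit`.

```python
import jax
import jax.numpy as jnp

def cumulative_product(x):
    # Running product via lax.scan: ys[i] = x[0] * ... * x[i].
    def step(carry, xi):
        new = carry * xi
        return new, new
    _, ys = jax.lax.scan(step, jnp.float32(1.0), x)
    return ys

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Forward-mode and reverse-mode Jacobians, both jitted.
J_fwd = jax.jit(jax.jacfwd(cumulative_product))(x)
J_rev = jax.jit(jax.jacrev(cumulative_product))(x)

# Last row is d(x0*x1*x2*x3)/dx -> [24., 12., 8., 6.]
print(J_fwd[3])
```

The two Jacobians agree, since plain `lax.scan` has both a JVP rule and a transpose rule.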