diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 7194bc33ea..5ef0d183b7 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -874,3 +874,18 @@ def remat( ), ) ) + """A 'lifted' version of the + `jax.checkpoint `__ + (a.k.a. ``jax.remat``). + + ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for + example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus + how they are recomputed during the backward pass, trading off memory and FLOPs. + + Learn more in `Flax NNX vs JAX Transformations `_. + + To learn about ``jax.remat``, go to JAX's + `fundamentals of jax.checkpoint `_ + and `practical notes `_. + """ +