From 8fbe99cf654f61e4ca30cebc18f4602721774d7c Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:26:00 +0000 Subject: [PATCH] Add flax.nnx.remat docstring --- flax/nnx/transforms/autodiff.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index b86823c527..038b9f63d2 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -873,3 +873,18 @@ def remat( ), ) ) + """A 'lifted' version of the + `jax.checkpoint `__ + (a.k.a. ``jax.remat``). + + ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for + example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus + how they are recomputed during the backward pass, trading off memory and FLOPs. + + Learn more in `Flax NNX vs JAX Transformations `_. + + To learn about ``jax.remat``, go to JAX's + `fundamentals of jax.checkpoint `_ + and `practical notes `_. + """ +