From e45552209228e5cabed6732726caaf245df98101 Mon Sep 17 00:00:00 2001
From: Johanna Haffner
Date: Fri, 20 Dec 2024 12:08:23 +0000
Subject: [PATCH] Make docstring of ForwardMode more precise, add references
 to it where forward-mode autodiff is mentioned in the other adjoints

---
 diffrax/_adjoint.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py
index 9c084ecc..46338ea2 100644
--- a/diffrax/_adjoint.py
+++ b/diffrax/_adjoint.py
@@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
     !!! info
 
         Note that this cannot be forward-mode autodifferentiated. (E.g. using
-        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
+        `jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
+        and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
+        only forward-mode autodifferentiation.
 
     ??? cite "References"
 
@@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):
 
     So unless you need forward-mode autodifferentiation then
     [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
+    If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
+    more efficient.
     """
 
     def loop(
@@ -855,11 +859,13 @@ def loop(
 
 
 class ForwardMode(AbstractAdjoint):
-    """Differentiate through a differential equation solve during the forward pass.
-    (So it is not really an adjoint - it is a different way of quantifying the
-    sensitivity of the output to the input.)
+    """Supports forward-mode automatic differentiation through a differential equation
+    solve. This works by propagating the derivatives during the forward pass - that is,
+    during the ODE solve - instead of solving the adjoint equations afterwards.
+    (So this is really a different way of quantifying the sensitivity of the output to
+    the input, even though its interface is that of an adjoint for convenience.)
 
-    ForwardMode is useful when we have many more outputs than inputs to a function - for
-    instance during parameter inference for ODE models with least-squares solvers such
-    as `optimistix.Levenberg-Marquardt`, that operate on the residuals.
+    This is useful when we have many more outputs than inputs to a function - for
+    instance during parameter inference for ODE models with least-squares solvers such
+    as `optimistix.LevenbergMarquardt`, which operate on the residuals.
     """
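
---

Reviewer note, not part of the patch: below is a minimal sketch of the use case
the new docstring describes - fitting a single ODE parameter with a
least-squares solver, where `diffrax.ForwardMode` propagates derivatives
forward alongside the solve (many residual outputs, few parameter inputs). The
toy decay model, the `residuals` function, and the tolerances are illustrative
assumptions; only the `diffrax` and `optimistix` calls themselves are real API.

    import diffrax
    import jax.numpy as jnp
    import optimistix as optx

    def residuals(k, args):
        # Residuals between the ODE solution for decay rate `k` and the data.
        ts, data = args
        term = diffrax.ODETerm(lambda t, y, k: -k * y)
        sol = diffrax.diffeqsolve(
            term,
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=0.1,
            y0=10.0,
            args=k,
            saveat=diffrax.SaveAt(ts=ts),
            # Propagate derivatives during the solve itself; this is the
            # forward-mode-only path that the new docstring recommends.
            adjoint=diffrax.ForwardMode(),
        )
        return sol.ys - data

    ts = jnp.linspace(0.0, 3.0, 20)
    data = 10.0 * jnp.exp(-0.5 * ts)  # synthetic observations, true rate 0.5
    solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
    fit = optx.least_squares(residuals, solver, y0=jnp.array(1.0), args=(ts, data))
    # fit.value recovers a decay rate close to 0.5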