Skip to content

Commit

Permalink
Make docstring of ForwardMode more precise, add references to it wher…
Browse files Browse the repository at this point in the history
…e forward-mode autodiff is mentioned in the other adjoints
  • Loading branch information
Johanna Haffner committed Dec 20, 2024
1 parent 08907c7 commit e455522
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
!!! info
Note that this cannot be forward-mode autodifferentiated. (E.g. using
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
only forward-mode autodifferentiation.
??? cite "References"
Expand Down Expand Up @@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):
So unless you need forward-mode autodifferentiation then
[`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
more efficient.
"""

def loop(
Expand Down Expand Up @@ -855,11 +859,13 @@ def loop(


class ForwardMode(AbstractAdjoint):
"""Differentiate through a differential equation solve during the forward pass.
(So it is not really an adjoint - it is a different way of quantifying the
sensitivity of the output to the input.)
"""Supports forward-mode automatic differentiation through a differential equation
solve. This works by propagating the derivatives during the forward-pass - that is,
during the ODE solve, instead of solving the adjoint equations afterwards.
(So this is really a different way of quantifying the sensitivity of the output to
the input, even if its interface is that of an adjoint for convenience.)
ForwardMode is useful when we have many more outputs than inputs to a function - for
This is useful when we have many more outputs than inputs to a function - for
instance during parameter inference for ODE models with least-squares solvers such
as `optimistix.Levenberg-Marquardt`, that operate on the residuals.
"""
Expand Down

0 comments on commit e455522

Please sign in to comment.