Forward mode "adjoint" #537

Merged: 7 commits, Dec 22, 2024

johannahaffner
Contributor

@johannahaffner johannahaffner commented Dec 9, 2024

Here you go! This is the pragmatic solution: it comes without support or test coverage for integer inputs, and with only a short comment explaining that forward mode is not really an adjoint, even though its diffrax interface is that of an adjoint.

Changes with respect to the last PR:

  • renamed to ForwardMode everywhere
  • added a sentence to the docstring explaining that this is not really an adjoint, while still inheriting from AbstractAdjoint
  • removed the stub for "forward gradient with int" from test_adjoint.py and explained that, since JAX does not offer this option, we're not writing our own workaround to test it either

On the last point: if I understood this correctly, supporting this would entail writing a gradient computation directly from a JVP with custom "unit pytrees". This is somewhat annoying for mixed array and non-array types.
I'm happy to try again if computing gradients with respect to integer elements of a PyTree is an expected use case (maybe arising from composed/layered transformations of a solve) that requires test coverage.
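
For illustration, a minimal sketch of building a gradient from JVPs alone, using one unit tangent per input element. It only handles a flat floating-point array; `grad_via_jvp` and the example function `f` are hypothetical names chosen here, not code from this PR. Extending the idea to a PyTree with mixed array and non-array leaves would need a separate unit tangent PyTree for each leaf element, which is the awkward part referred to above.

```python
import jax
import jax.numpy as jnp


def grad_via_jvp(f, x):
    """Gradient of a scalar function f at a 1-D float array x, built from JVPs."""

    def directional(tangent):
        # jax.jvp returns (primal_out, tangent_out); only the tangent is needed.
        _, tangent_out = jax.jvp(f, (x,), (tangent,))
        return tangent_out

    # Each row of the identity matrix is a unit tangent for one input element.
    basis = jnp.eye(x.size, dtype=x.dtype)
    return jax.vmap(directional)(basis)


def f(x):
    # Hypothetical scalar test function.
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([0.5, 1.0, 2.0])
g_fwd = grad_via_jvp(f, x)  # forward mode, one JVP per input element
g_rev = jax.grad(f)(x)      # reverse mode, for comparison
```

The two results should agree up to floating-point error. The cost of the forward-mode version grows with the number of inputs, since it needs one JVP per element.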

Earlier comments here.

(This is now rebased on main.)

Owner

@patrick-kidger patrick-kidger left a comment

Nits aside, this LGTM! Once you fix things up I'll merge this into a new dev branch.

diffrax/_adjoint.py (outdated)

class ForwardMode(AbstractAdjoint):
    """Differentiate through a differential equation solve during the forward pass.
    (So it is not really an adjoint - it is a different way of quantifying the
Owner

I might specify what an adjoint is here, to make this clear.

Contributor Author

Gave it a shot!

docs/requirements.txt (outdated)
@johannahaffner
Contributor Author

Completed the above fixes :)

I also added references to ForwardMode to the docstrings of the other adjoints, where they point to DirectAdjoint for forward-mode automatic differentiation. I'm recommending it for the case where only forward-mode automatic differentiation is required, but I'm not sure if I should hedge/tone down my statement that ForwardMode is more efficient than DirectAdjoint. That is - I'm not sure if this statement holds in general. What do you think?

While going through the errors raised by the other adjoints, I was also wondering whether there is a case that ForwardMode cannot handle, such as the UnsafeBrownianPath, for which some of the other adjoints error out.

@johannahaffner
Contributor Author

I benchmarked these a little bit by computing the Jacobian of the output ys for a variable number of states, once with one parameter for each state and once with a single parameter for all states. (So the Jacobians had shape (num_timepoints, num_states, num_states) in the first case and (num_timepoints, num_states, 1) in the second case.)

I did not include any couplings between states, but this roughly matches what I have seen on my own data (with coupled ODEs), where DirectAdjoint takes about 1.5x as long as ForwardMode.

Equal number of states and parameters, uncoupled ODE

Largest Jacobian (1024, 512, 512). DirectAdjoint takes ~2.3x as long at both the high and the low end.

[Image: Screenshot 2024-12-21 at 12 44 30]

One parameter for all states, uncoupled ODE

Largest Jacobian (4096, 512, 1). DirectAdjoint takes ~1.15x as long at the high end and ~2.4x as long at the low end.

[Image: Screenshot 2024-12-21 at 13 05 50]

@patrick-kidger patrick-kidger changed the base branch from main to dev December 21, 2024 21:11
@patrick-kidger patrick-kidger merged commit 0beb5ce into patrick-kidger:dev Dec 22, 2024
2 checks passed
@patrick-kidger
Owner

LGTM! And merged :)

As for DirectAdjoint -- yup, it's totally expected that this will be slower than ForwardMode. Honestly DirectAdjoint is kind of a terrible choice, it only exists because of the very occasional need to support both forward and reverse mode through the same solve. (Hopefully one day JAX implements jvp-of-custom_vjp.)
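
To make the jvp-of-custom_vjp remark concrete, here is a toy sketch in plain JAX, unrelated to diffrax internals. The function `square` and its rules are hypothetical. Reverse mode works through a `jax.custom_vjp` function, while forward mode is, as far as I know, rejected with a `TypeError` in JAX versions from around the time of this thread; the code records what happens instead of assuming the outcome.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def square(x):
    return x**2


def square_fwd(x):
    # Forward pass: return the output and the residual needed for the backward pass.
    return x**2, x


def square_bwd(x, g):
    # Backward pass: cotangent with respect to the single input.
    return (2.0 * x * g,)


square.defvjp(square_fwd, square_bwd)

x = jnp.asarray(3.0)

# Reverse mode uses the custom rule.
reverse_grad = jax.grad(square)(x)

# Forward mode through a custom_vjp function: record whether JAX accepts it.
try:
    jax.jvp(square, (x,), (jnp.asarray(1.0),))
    forward_supported = True
except TypeError:
    forward_supported = False
```

If `forward_supported` ends up `False`, that is the limitation referred to above: an adjoint built on `custom_vjp` cannot also serve forward mode, which is the gap DirectAdjoint fills.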

@johannahaffner johannahaffner deleted the forward-mode branch December 22, 2024 09:52
patrick-kidger pushed a commit that referenced this pull request Dec 24, 2024
* add .venv

* add code, tests and documentation for ForwardAdjoint

* make version of mkdocs-autorefs explicit (patrick-kidger/optimistix#91, but for diffrax)

* rename, add documentation, explicate lack of test coverage for unit-input case.

* rename import of ForwardMode

* fix duplicate

* Make docstring of ForwardMode more precise, add references to it where forward-mode autodiff is mentioned in the other adjoints

---------

Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>