-
-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Forward mode "adjoint" #537
Forward mode "adjoint" #537
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nita aside this LGTM! Once you fix things up I'll merge this into a new dev branch.
diffrax/_adjoint.py
Outdated
|
||
class ForwardMode(AbstractAdjoint): | ||
"""Differentiate through a differential equation solve during the forward pass. | ||
(So it is not really an adjoint - it is a different way of quantifying the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might specify what an adjoint is here, to make this clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gave it a shot!
… but for diffrax)
…e forward-mode autodiff is mentioned in the other adjoints
f75dd51
to
e455522
Compare
Completed the above fixes :) I also added references to While going though the errata raised by the other adjoints I was also wondering if there is a case |
LGTM! And merged :) As for |
* add .venv * add code, tests and documentation for ForwardAdjoint * make version of mkdocs-autorefs explicit (patrick-kidger/optimistix#91, but for diffrax) * rename, add documentation, explicate lack of test covarage for unit-input case. * rename import of ForwardMode * fix duplicate * Make docstring of ForwardMode more precise, add references to it where forward-mode autodiff is mentioned in the other adjoints --------- Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>
Here you go! This is the pragmatic solution, without support or test coverage for integer inputs and only a small comment explicating that forward mode is not really an adjoint, even though its diffrax interface is that of an adjoint.
Changes with respect to the last PR:
ForwardMode
everywhereAbstractAdjoint
test_adjoint.py
and explain that since JAX does not offer this option, we're not writing our own workaround to test it eitherOn the last point: if I understood this correctly, then supporting this would entail writing a gradient-computation directly from a JVP with custom "unit pytrees". This is somewhat annoying for mixed array and non-array types.
I'm happy to try again if computing gradients with respect to integer elements of a PyTree is an expected use case (maybe arising from composed/layered transformations of a solve) that requires test coverage.
Earlier comments here.
(This is now rebased on main.)