-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
lax.custom_linear_solve primitive #1402
Conversation
The primitive's implementation and JVP rules are now 100% pytree free.
Please take another look. I rewrote everything so |
This is ready for a final (?) review. |
"""Transpose a linear function.""" | ||
# TODO(shoyer): can we use something more direct than the vjp machinery? | ||
# It's particularly awkward that we need the second argument to give | ||
# particular values of the primals, which are entirely arbitrary. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. It might be a new transformation, but one that shares a lot of machinery with the existing jvp/transpose.
A follow-on to
lax.root
from #1339. See the tests for a simple example of defining gradients for an iterative linear solver.Eventually, I expect to use this for implementing iterative solvers (e.g., GMRES, CG) from
scipy.sparse.linalg
, which in turn we will use for the linear solves needed for gradients inscipy.optimize.root
.