Feature request: Support Python control flow in custom_transforms functions #1275
Comments
Does #1255 fix this?
Correct me if I'm wrong, but the custom VJP functions here still aren't using Python control flow. I understand that there are workarounds using the lax-based control flow, but it would be nice if we could just write Python loops and if-statements. This is perhaps not necessary for the ODE adjoint, and the example code doesn't seem that complicated. But it might add overhead when the research we're trying to implement has more complicated control flow.
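To make the workaround mentioned above concrete, here is a minimal sketch contrasting a plain Python loop with its `lax.while_loop` equivalent. The function names `count_halvings_py` and `count_halvings_lax` are made up for illustration and do not come from this thread.

```python
import jax
from jax import lax


def count_halvings_py(x):
    """Plain Python loop: fine eagerly, but cannot be traced by jax.jit."""
    n = 0
    while x >= 1.0:
        x, n = x / 2.0, n + 1
    return n


@jax.jit
def count_halvings_lax(x):
    """The same loop written with lax.while_loop so it can be staged out."""
    def cond(state):
        x, _ = state
        return x >= 1.0

    def body(state):
        x, n = state
        return x / 2.0, n + 1

    _, n = lax.while_loop(cond, body, (x, 0))
    return n


print(count_halvings_py(10.0), int(count_halvings_lax(10.0)))
```

The lax version has to thread all loop state through an explicit carry tuple, which is the overhead being described. Note also that `lax.while_loop` is not reverse-mode differentiable, so it does not help on its own for a backward pass that itself needs differentiating.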
Unsurprisingly, it seems that unless you want to use Being able to do this in JAX would be quite helpful to researchers, whose goal in many cases is to test out ideas fast.
One option for using arbitrary control flow is to write new JAX primitives: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
We’re working on this. The reason it isn’t as simple as it was for Autograd is that JAX uses a new autodiff design in which we only have forward mode and derive reverse mode automatically (composing forward mode with other transformations). That confers several advantages, but a disadvantage is that since the system itself doesn’t work in terms of VJPs, supporting custom VJPs is tricky. (You can write custom JVPs, i.e., forward-mode rules, with arbitrary Python control flow now.)
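As a rough illustration of a forward-mode rule with Python control flow, the sketch below uses the `jax.custom_jvp` decorator from current JAX releases, which may differ from the API available when this comment was written. The function `newton_sqrt` is a hypothetical example, not code from this thread, and it assumes a non-negative scalar input evaluated outside `jax.jit`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def newton_sqrt(a):
    """Square root of a non-negative scalar via Newton's method."""
    a = jnp.asarray(a)
    x = a
    for _ in range(50):                   # hard cap on the iteration count
        if abs(x * x - a) <= 1e-6 * a:    # data-dependent early exit
            break
        x = 0.5 * (x + a / x)
    return x


@newton_sqrt.defjvp
def newton_sqrt_jvp(primals, tangents):
    (a,), (a_dot,) = primals, tangents
    x = newton_sqrt(a)
    if x == 0:                            # Python branch inside the rule itself
        return x, 0.0 * a_dot             # convention chosen here for a == 0
    return x, a_dot / (2.0 * x)           # d sqrt(a) = da / (2 sqrt(a))


dsqrt_at_4 = jax.grad(newton_sqrt)(4.0)
```

The rule differentiates the mathematical function rather than the loop, so the derivative does not depend on how many Newton iterations ran. The Python `if` and `break` only work because the primal values are concrete here; under `jax.jit` they would be tracers and the comparisons would fail.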
#2026 finally landed and added support for Python control flow in custom derivative rules! (The API also changed, so take a look at the tutorial notebook.)
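For readers who have not seen the newer API, here is a minimal sketch using `jax.custom_vjp` with an ordinary Python `if` in the backward rule. The function `clip_cotangent` and its helpers are invented for illustration, and the example assumes it is differentiated eagerly (not under `jax.jit`), so that the cotangent is a concrete array.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def clip_cotangent(x):
    """Identity in the forward pass; rescales large cotangents on the way back."""
    return x


def _clip_fwd(x):
    return x, None  # nothing needs to be saved for the backward pass


def _clip_bwd(_, g):
    norm = jnp.linalg.norm(g)
    if norm > 1.0:  # ordinary Python branch on a concrete value
        g = g / norm
    return (g,)


clip_cotangent.defvjp(_clip_fwd, _clip_bwd)

big = jax.grad(lambda x: jnp.sum(10.0 * clip_cotangent(x)))(jnp.ones(4))
small = jax.grad(lambda x: jnp.sum(0.1 * clip_cotangent(x)))(jnp.ones(4))
```

In the first call the incoming cotangent has norm 20 and is rescaled to unit norm; in the second it has norm 0.2 and passes through unchanged.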
Thanks for consistently pushing this forward and the amazing work!
For fitting parameter values for ODEs a la the adjoint sensitivity method, we might want to override the gradient computation for the forward ODE solve. More concretely, we might have an integrator function `odeint` that takes in a gradient field `f`, initial state `y0`, and a sequence of times `ts` to be evaluated at.

One specific use case where supporting control flow in `custom_transforms` will be useful is for the backward integration (which might involve adaptive solvers, hence non-trivial control flow). Ideally, we would like to write code as follows
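The snippet below is not the code originally posted with this issue. It is a minimal sketch of the idea, written against the `jax.custom_vjp` API from current JAX releases, and several pieces are assumptions made for illustration: the extra `theta` parameter argument, the helper `_solve` (a simple adaptive Heun/Euler scheme standing in for a real adaptive solver), a floating-point array `ts`, and the test problem `decay`. It also assumes eager execution, since the step-size control uses Python `while` and `if` on concrete values.

```python
import functools

import jax
import jax.numpy as jnp
from jax import tree_util


def _solve(func, z, t0, t1, tol=1e-4):
    """Integrate dz/dt = func(z, t) from t0 to t1 (t1 < t0 is allowed)."""
    t, h = t0, (t1 - t0) / 8.0  # h is signed, so backward solves also work
    while abs(t1 - t) > 1e-9:  # Python loop over concrete values
        if abs(h) > abs(t1 - t):
            h = t1 - t  # do not step past the end point
        k1 = func(z, t)
        k2 = func(tree_util.tree_map(lambda z_, k: z_ + h * k, z, k1), t + h)
        # Heun update minus Euler update, used as a local error estimate.
        err = max(
            float(jnp.max(jnp.abs(0.5 * h * (b - a))))
            for a, b in zip(tree_util.tree_leaves(k1), tree_util.tree_leaves(k2))
        )
        if err <= tol:  # accept the Heun step and try a larger one next
            z = tree_util.tree_map(
                lambda z_, a, b: z_ + 0.5 * h * (a + b), z, k1, k2)
            t, h = t + h, 1.5 * h
        elif abs(h) < 1e-12:
            raise RuntimeError("step size underflow")
        else:  # reject and retry with half the step
            h = 0.5 * h
    return z


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def odeint(f, y0, ts, theta):
    """Solution of dy/dt = f(y, t, theta) at every time in ts."""
    ys = [jnp.asarray(y0)]
    for t_prev, t_next in zip(ts[:-1], ts[1:]):
        ys.append(_solve(lambda y, t: f(y, t, theta), ys[-1],
                         float(t_prev), float(t_next)))
    return jnp.stack(ys)


def _odeint_fwd(f, y0, ts, theta):
    ys = odeint(f, y0, ts, theta)
    return ys, (ys, ts, theta)


def _odeint_bwd(f, residuals, g):
    ys, ts, theta = residuals

    def augmented(z, t):
        # State is (y, adjoint a, accumulated parameter gradient).
        y, a, _ = z
        dy, vjp_fn = jax.vjp(lambda y_, th: f(y_, t, th), y, theta)
        a_y, a_theta = vjp_fn(a)
        return dy, -a_y, -a_theta

    a, g_theta = g[-1], jnp.zeros_like(theta)
    for i in range(len(ts) - 1, 0, -1):  # integrate backward, segment by segment
        _, a, g_theta = _solve(augmented, (ys[i], a, g_theta),
                               float(ts[i]), float(ts[i - 1]))
        a = a + g[i - 1]  # add the cotangent of the output at ts[i - 1]
    return a, jnp.zeros_like(ts), g_theta  # no gradient is computed for ts


odeint.defvjp(_odeint_fwd, _odeint_bwd)


def decay(y, t, theta):  # hypothetical test problem: dy/dt = -theta * y
    return -theta * y


ts = jnp.array([0.0, 0.5, 1.0])
loss = lambda y0, theta: odeint(decay, y0, ts, theta)[-1]
g_y0, g_theta = jax.grad(loss, argnums=(0, 1))(jnp.array(1.0), jnp.array(0.5))
```

For this test problem the exact solution is y(1) = y0 * exp(-theta), so the two gradients should come out close to exp(-0.5) and -exp(-0.5), up to solver tolerance. The backward pass never differentiates through the forward solver's loop; it solves a second ODE for the adjoint, which is why it needs its own control flow.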