
Feature request: Support python control flow in custom_transforms functions #1275

Closed
lxuechen opened this issue Aug 30, 2019 · 7 comments
Labels: enhancement (New feature or request)

Comments

@lxuechen commented Aug 30, 2019

For fitting parameter values of ODEs à la the adjoint sensitivity method, we might want to override the gradient computation for the forward ODE solve. More concretely, we might have an integrator function odeint that takes in a vector field f, an initial state y0, and a sequence of times ts at which the solution should be evaluated.

One specific use case where supporting control flow in custom_transforms would be useful is the backward integration (which might involve adaptive solvers, and hence non-trivial control flow). Ideally, we would like to write code like the following:

from jax import custom_transforms, defvjp  # the pre-2020 custom-derivative API

@custom_transforms
def odeint(y0, ts):
  pass  # Some procedure integrating the vector field `f`.

def vjp_y0(g, ans, y0, ts):
  pass  # A while loop and some if statements used to determine the integration step size.

defvjp(odeint, vjp_y0, None)

@mattjj added the enhancement label on Sep 2, 2019
@jacobjinkelly (Contributor)

Does #1255 fix this?

@lxuechen (Author) commented Nov 4, 2019

Correct me if I'm wrong, but the custom VJP functions there still aren't using Python control flow. I understand that there are workarounds using the lax-based control flow, but it would be nice if one could just write Python loops and if statements.

This is perhaps not necessary for the ODE adjoint, and the example code doesn't seem that complicated. But it could add real overhead when the research we're trying to implement has more complicated control flow.
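For reference, here is a minimal sketch of the kind of lax-based workaround I mean, next to the Python version one would prefer to write; shrink_step is a made-up step-size helper, not anything from JAX:

from jax import lax

def shrink_step_python(dt, err, tol=1e-3):
  # The version we'd like to write: a plain Python loop, which only works on
  # concrete values.
  while err * dt > tol:
    dt = dt / 2.0
  return dt

def shrink_step_lax(dt, err, tol=1e-3):
  # The lax-based workaround: the same logic as a functional while_loop, which
  # also works under tracing (e.g. inside jit).
  return lax.while_loop(lambda d: err * d > tol, lambda d: d / 2.0, dt)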

@lxuechen (Author) commented Nov 4, 2019

Unsurprisingly, it seems that unless you want to use tf.function or graph-mode execution, writing a few Python if statements and loops in the gradient function passed to tf.custom_gradient doesn't cause any trouble in TF 2.0.

Being able to do this in JAX would be quite helpful to researchers, for whom the goal is often to test out ideas quickly.
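To make the comparison concrete, here is a rough sketch (not code from this thread) of what that looks like with eager TF 2.x; clipped_square is a made-up example:

import tensorflow as tf

@tf.custom_gradient
def clipped_square(x):
  y = tf.square(x)

  def grad(dy):
    g = 2.0 * x * dy
    # An ordinary Python if on a concrete eager tensor.
    if float(tf.reduce_max(tf.abs(g))) > 10.0:
      g = tf.clip_by_value(g, -10.0, 10.0)
    return g

  return y, grad

Under eager execution, tf.GradientTape runs the Python branch above without complaint; it is only under tf.function that this kind of control flow needs more care.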

@shoyer (Collaborator) commented Nov 4, 2019

One option for using arbitrary control flow is to write new JAX primitives: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
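A rough sketch of that route, following the pattern in the linked notebook (my_relu and its rules are illustrative stand-ins, not part of JAX): the impl can contain arbitrary Python because it only ever sees concrete values.

import numpy as np
from jax import core
from jax.interpreters import ad

relu_p = core.Primitive("my_relu")

def relu_impl(x):
  # Runs on concrete values, so any Python control flow is allowed here.
  return np.where(x > 0, x, 0.0)

def relu_abstract_eval(x):
  return core.ShapedArray(x.shape, x.dtype)

def relu_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return relu_p.bind(x), t * (x > 0)

relu_p.def_impl(relu_impl)
relu_p.def_abstract_eval(relu_abstract_eval)
ad.primitive_jvps[relu_p] = relu_jvp

def my_relu(x):
  return relu_p.bind(x)

The catch, as the notebook explains, is that each new primitive also needs rules for every other transformation (jit, vmap, ...) it should support.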

@mattjj (Collaborator) commented Nov 4, 2019

We're working on this. The reason it isn't as simple as it was for Autograd is that JAX uses a new autodiff design in which we only have forward mode and derive reverse mode automatically (by composing forward mode with other transformations). That confers several advantages, but a disadvantage is that, since the system itself doesn't work in terms of VJPs, supporting custom VJPs is tricky. (You can write custom JVPs, i.e. forward-mode rules, with arbitrary Python control flow now.)
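For instance, a forward-mode rule can branch on primal values, which are concrete as long as you're not under jit. The sketch below uses the later jax.custom_jvp spelling rather than the API available at the time of this comment, and safe_log is a made-up example for a scalar input:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_log(x):
  return jnp.log(x)

@safe_log.defjvp
def safe_log_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  ans = safe_log(x)
  # Plain Python branching on the primal value (assumed scalar and concrete).
  if x < 1e-6:
    return ans, t / 1e-6  # clamp the derivative near zero
  return ans, t / x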

@mattjj (Collaborator) commented Mar 22, 2020

#2026 finally landed and added support for Python control flow in custom derivative rules! (The API also changed, so take a look at the tutorial notebook.)
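As a quick illustration in the spirit of the original post, here is a sketch of a jax.custom_vjp whose backward rule runs a plain Python while loop with a data-dependent stopping condition; newton_sqrt is a toy stand-in for a scalar input, not a real adjoint ODE solve:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def newton_sqrt(a):
  return jnp.sqrt(a)

def newton_sqrt_fwd(a):
  return jnp.sqrt(a), a  # save `a` as the residual for the backward pass

def newton_sqrt_bwd(a, g):
  # Re-solve y = sqrt(a) by Newton iteration with an adaptive stopping rule;
  # the Python while loop sees concrete values when grad is taken outside jit.
  y = jnp.maximum(a, 1.0)
  while float(jnp.abs(y * y - a)) > 1e-6:
    y = 0.5 * (y + a / y)
  return (g / (2.0 * y),)  # d sqrt(a)/da = 1 / (2 sqrt(a))

newton_sqrt.defvjp(newton_sqrt_fwd, newton_sqrt_bwd)

# jax.grad(newton_sqrt)(4.0)  # ~0.25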

@mattjj closed this as completed on Mar 22, 2020
@lxuechen (Author)

Thanks for consistently pushing this forward, and for the amazing work!
