Custom VJPs for external functions #1142
Comments
I'm having the same issue. I cannot get the example code from the defvjp_all documentation to work.
System information:
Build information:
Minimal code to reproduce (jax_test.py):
Stack trace:
Currently custom gradients have the same restrictions as […]
Even if we get this working, I suspect you may still find performance to be painfully slow if you can't […]
Hi all, sorry for the slow response! @tpr0p @MRBaozi

The issue here is the difference between a […]. From the […]:

Let me unpack that, because it's not very detailed. The […]. In contrast, a JAX […]. We haven't documented how to set up your own Primitive yet, but here's the idea:

```python
from jax import core
import numpy as onp  # I changed this name out of habit

# Set up a Primitive, using a handy level of indirection
def foo(x):
    return foo_p.bind(x)

foo_p = core.Primitive('foo')
```

At this point there are no rules defined for foo:

```python
In [2]: foo(3)
NotImplementedError: Evaluation rule for 'foo' not implemented
```

Let's define an evaluation rule in terms of NumPy:

```python
foo_p.def_impl(onp.square)
```

And now:

```python
In [4]: foo(3)
Out[4]: 9
```

Woohoo! But we can't do anything else with it. We can add a VJP rule like this (though actually for all our primitives we instead define a JVP rule; this might be more familiar, cf. #636):

```python
from jax.interpreters import ad
ad.defvjp(foo_p, lambda g, x: 2 * x * g)
```

And now:

```python
In [5]: from jax import grad

In [6]: grad(foo)(3.)
Out[6]: DeviceArray(6., dtype=float32)
```

There's also an API closer to the one in @tpr0p's original example:

```python
def f_vjp(x):
    return foo(x), lambda g: (2 * g * x,)

ad.defvjp_all(foo_p, f_vjp)
```

To use […]

Does that make sense? What'd I miss?
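To tie the fragments above together, here is a consolidated version of the defvjp_all variant as one runnable snippet. This is an assembly of the pieces above rather than a quote from the thread, and it assumes the jax.interpreters.ad helpers available in the JAX versions of that era (they are not part of today's public API):

```python
from jax import core, grad
from jax.interpreters import ad
import numpy as onp

def foo(x):
    return foo_p.bind(x)

foo_p = core.Primitive('foo')

# Evaluation rule: plain NumPy implementation of the primitive.
foo_p.def_impl(onp.square)

# defvjp_all registers a rule that returns the primal output together with
# a VJP closure over the inputs.
def f_vjp(x):
    return foo(x), lambda g: (2 * g * x,)

ad.defvjp_all(foo_p, f_vjp)

print(grad(foo)(3.))  # expected 6.0, matching the defvjp example above
```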
Thanks so much for the thorough explanation @mattjj! I managed to get everything working now. I can see how this approach is more general, though the solution is somewhat non-obvious from an outsider's perspective ;-)
Hi @mattjj, the example code you give works fine. However, if my self-implemented VJP function contains raw NumPy functions, I run into the same issue:
When I run the code:
System information:
@HamletWantToCode my understanding is that JVP rules cannot make use of NumPy functions directly, because JAX wants to support higher-order differentiation. You could make your pure-NumPy example work by defining another primitive:

```python
import numpy as onp
from jax.core import Primitive
from jax.interpreters.ad import defvjp
from jax import grad

# Define the function to be differentiated
def foo(x):
    return foo_p.bind(x)

foo_p = Primitive('foo')

def f(x):
    return onp.sin(x)

foo_p.def_impl(f)

def dfoo(g, x):
    return g * bar(x)

defvjp(foo_p, dfoo)

# The derivative is itself a primitive, so the VJP stays differentiable.
def bar(x):
    return bar_p.bind(x)

bar_p = Primitive('bar')

def g(x):
    return onp.cos(x)

bar_p.def_impl(g)

def dbar(g, x):
    return -g * foo(x)

defvjp(bar_p, dbar)
```
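A quick sanity check of the mutually recursive rules above (an addition for reference; the expected values are simply the standard derivatives of sin and cos, not output quoted from this thread):

```python
# Assumes the foo/bar primitives from the snippet above have been defined.
print(grad(foo)(1.0))  # ≈ cos(1.0) ≈ 0.5403
print(grad(bar)(1.0))  # ≈ -sin(1.0) ≈ -0.8415
```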
Thank you very much @shoyer
@mattjj should we consider this fixed with the new custom gradients code?
Actually, for external functions a new primitive should be used, not custom_jvp/vjp stuff. That is, external functions fall into case 2 articulated at the top of the "Custom derivative rules for JAX-transformable Python functions" tutorial. I think this topic is important enough that it needs its own tutorial explanation (i.e. I don't think the "How JAX primitives work" tutorial is quite the right explanation for people looking to solve this particular issue, just because we should have more direct examples for this use case).
I changed the issue label to "documentation" so that we can add such a tutorial.
The new custom gradients code does make this easier in some cases, though indeed a primitive would allow for a more complete solution. Here's a rough prototype I worked out for TensorFlow 2 <-> JAX, which may be a useful point of reference:
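For context, here is a minimal jax.custom_vjp example on a pure-JAX function (an illustration of the "new custom gradients" API added here for reference; it is not the TF2 <-> JAX prototype mentioned above):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Save x as the residual needed by the backward pass.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # d/dx log(1 + e^x) = sigmoid(x)
    return (g * jax.nn.sigmoid(x),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(0.0))  # 0.5
```

This works for functions written in JAX, but it does not by itself make an external (non-JAX) function traceable, which is why a primitive or a callback mechanism is still needed for the external case.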
The current recommendation for wrapping external functions in a way that is compatible with JIT would be to use […]. If you need to use a pre-existing VJP rule, then I think you need to use […]
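The specific APIs named in the comment above were lost in this copy. As one hedged illustration of the general pattern on recent JAX versions (an assumption about what fits here, not necessarily what the comment recommended), an external NumPy function can be wrapped with jax.pure_callback and given a hand-written VJP via jax.custom_vjp:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Stand-in "external" functions implemented in plain NumPy (hypothetical).
def external_f(x):
    return np.sin(x)

def external_df(x):
    return np.cos(x)

@jax.custom_vjp
def f(x):
    # pure_callback lets the NumPy call run under jit/vmap by declaring
    # the output shape and dtype up front.
    return jax.pure_callback(external_f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def f_fwd(x):
    return f(x), x

def f_bwd(x, g):
    df = jax.pure_callback(external_df, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return (g * df,)

f.defvjp(f_fwd, f_bwd)

print(jax.jit(jax.grad(f))(jnp.array(1.0)))  # ≈ cos(1.0) ≈ 0.5403
```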
Hello. It seems that you were doing something like inverse photonics design with a solver. I am currently working on a project on topology optimization for photonic structures, and I am struggling to apply Autograd to a solver so that I can get the gradient. If possible, could I ask how you did that? I still have no clue, and I would really appreciate it if we could discuss it.
JAX now has a supported/recommended way of doing this:
In case anyone is still interested in using autograd functions in a jit/vmap/jacrev-compatible way, I have an experimental wrapper in the agjax package which addresses this: https://github.com/mfschubert/agjax
Hi! I want to define custom gradients for a simulation for sensitivity analysis. I have been using autograd for this, but since it is not actively being developed anymore I wanted to switch to jax.
In autograd I would write something like this:
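The snippet that followed here was not preserved in this copy of the issue. As a hedged, representative example of the autograd pattern being described (a hypothetical simulate function standing in for the poster's actual solver code):

```python
import numpy as onp
from autograd.extend import primitive, defvjp
from autograd import grad

@primitive
def simulate(x):
    # Black-box computation in plain NumPy (stand-in for an external solver).
    return onp.sin(x)

def simulate_vjp(ans, x):
    # VJP maker: given the answer and inputs, return a function of the cotangent.
    return lambda g: g * onp.cos(x)

defvjp(simulate, simulate_vjp)

print(grad(simulate)(1.0))  # ≈ cos(1.0)
```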
In autograd, this worked fine and I was able to chain this together with some other differentiable transformations and get gradients out of the whole thing.
From what I was able to gather, the above would be written in jax as follows:
However, this throws

Exception: Tracer can't be used with raw numpy functions.

which I assume is because the simulation code does not use jax. Are the custom gradients in jax not black boxes as in autograd anymore, i.e. is this a fundamental limitation or have I screwed something up? Do I need to implement this using lax primitives, and if so, how? I would be grateful for a minimal example implementing this for some arbitrary non-jax function. This code here, for example, works in autograd:
How would one translate this so it works in jax?
Thanks so much!