add custom_jvp / vjp, delete custom_transforms #2026
Conversation
Force-pushed from e67d117 to 3c8e501
Very cool! I'm going to see how it looks to use this. How hard would it be to support defining a transpose rule along with the jvp rule?
See #2029 for my attempt to use this. Generally this went well, but I think it turned up a bug for higher-order differentiation of custom_jvp. Here's my simplified test case:

```python
import jax

def fwd(x):
return 2 * x, None
def jvp(aux, g):
return 3 * g
f = jax.custom_jvp(fwd, jvp)
def f2(x):
# should match f
y, _ = jax.jvp(f, (x,), (x,))
return y
x = 1.0
print('f:', f(x))
print('f jvp:', jax.jvp(f, (x,), (x,)))
print('f2:', f2(x))
print('f2 jvp:', jax.jvp(f2, (x,), (x,)))
```

Outputs:
Notice that the custom JVP rule only affects direct differentiation of f. My guess is that this comes down to how the custom JVP function is called directly inside f2.
Thanks for trying this out, @shoyer! Re: higher-order differentiation, that behavior was actually intentional. But I like the point you're making.
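For reference, the custom-JVP interface that eventually shipped differs from the constructor-style jax.custom_jvp(fwd, jvp) used in the test case above. A minimal sketch of the released jax.custom_jvp API applied to the same toy rule:

```python
import jax

@jax.custom_jvp
def f(x):
    return 2 * x

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    g, = tangents
    return f(x), 3 * g  # primal output and the custom tangent

print(jax.jvp(f, (1.0,), (1.0,)))  # (2.0, 3.0)
print(jax.grad(f)(1.0))            # 3.0, reverse mode derived from the JVP rule
```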
OK, great! I'm hoping that will also fix higher-order differentiation for my differentiable binary search over in #2029.
Btw, I'm not confident in the commit I pushed. Gotta iterate more tomorrow!
Force-pushed from b921d44 to cde7e1f
Force-pushed from b606c54 to b11c0b2
Hi!
Alternatively, the implementation could be defined differently; see the script I used for experiments: https://gist.github.com/IvanYashchuk/bd1a1aaddf952d66569a55b8fba48e67
@IvanYashchuk could you give a specific example of why you need these changes? What doesn't work with the current version of this PR?
Let's consider only custom_vjp:

```python
import jax
import jax.numpy as jnp
from jax.core import Primitive
@jax.api.custom_vjp
def external_cos(x):
return jnp.cos(x)
@jax.api.custom_vjp
def external_sin(x):
return jnp.sin(x)
def fwd_external_cos(x):
out_primal = external_cos(x)
out_tangent = -external_sin(x)
return out_primal, out_tangent
def rev_external_cos(fwd_out, g):
return (g * fwd_out,)
def fwd_external_sin(x):
out_primal = external_sin(x)
out_tangent = external_cos(x)
return out_primal, out_tangent
def rev_external_sin(fwd_out, g):
return (g * fwd_out,)
external_cos.defvjp(fwd_external_cos, rev_external_cos)
external_sin.defvjp(fwd_external_sin, rev_external_sin)
# vjp, nested grad and jacrev are OK as expected
x = 1.0
res, vjp_cos = jax.vjp(external_cos, x)
vjp_res = vjp_cos(1.0) # -sin(x) * 1.0
jax.grad(jax.grad(jax.grad(jax.grad(external_cos))))(1.0) # sin(x)
# This raises an error on current version of PR!
jax.jvp(external_cos, (jnp.ones(1),), (0.5*jnp.ones(1),))
```

Applying this small change IvanYashchuk/jax@97625bc makes it possible to use jax.jvp:

```python
# Now this works with IvanYashchuk/jax@97625bc
res, jvp_res = jax.jvp(external_cos, (jnp.ones(1),), (0.5*jnp.ones(1),))
# or
_, jvp_fun = jax.linearize(external_cos, x)
jvp_res2 = jvp_fun(0.5)
assert jvp_res == jvp_res2
```

Good, but …

With this change, jacfwd also works:

```python
jax.jacfwd(external_cos)(jnp.ones((5, 1)))
```

JVP through the linearize function and through grad also works:

```python
_, jvp_fun = jax.linearize(external_cos, x)
jax.jvp(jvp_fun, (1.0,), (2.0,))  # (-sin(x)*primal, -sin(x)*tangent)
jax.jvp(jax.grad(external_cos), (1.0,), (2.0,))  # (-sin(primal), -cos(primal)*tangent)
```

Hessian also works:

```python
ff = lambda x: jnp.sum(external_cos(x))
jax.hessian(ff)(jnp.ones((5, 1)))
```

This is all great; it seems like everything works now! However, …
@IvanYashchuk thanks for digging in! However, this won't work in general:

```python
tangents_out = rev.call_wrapped(*(res + tangents_in))
```

It's off by a transpose. You might not have noticed if you only tried it with elementwise functions, like sine and cosine, essentially because the transpose of a diagonal operator is the operator itself. But for general functions (e.g. involving dot or conv) the shapes don't even work out.

One thing we can do better is automatically derive a forward-mode rule from a custom VJP definition by automatic transposition, exactly analogous to how this PR already supports reverse mode over a custom JVP rule.

(By the way, another problem with the diff removing …)
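To make the "off by a transpose" point concrete: for a non-square linear map, the JVP and the VJP act between different spaces, so feeding a tangent into a VJP rule cannot even typecheck. A small illustrative sketch (not code from the PR) using plain jax.jvp and jax.vjp:

```python
import jax
import jax.numpy as jnp

A = jnp.ones((3, 2))
f = lambda x: A @ x  # linear map from R^2 to R^3

x = jnp.ones(2)

# JVP pushes a tangent forward: v in R^2 maps to A @ v in R^3.
_, tangent_out = jax.jvp(f, (x,), (jnp.ones(2),))
print(tangent_out.shape)  # (3,)

# VJP pulls a cotangent back: g in R^3 maps to A.T @ g in R^2.
_, vjp_fun = jax.vjp(f, x)
cotangent_out, = vjp_fun(jnp.ones(3))
print(cotangent_out.shape)  # (2,)

# Passing the (2,)-shaped tangent to vjp_fun raises a shape error; the two
# maps only coincide when the Jacobian equals its own transpose (e.g. the
# diagonal Jacobians of elementwise functions like sin and cos).
```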
@mattjj thank you for the explanations! I'm just scratching the surface here and trying to understand the inner workings. I didn't notice this test case, which fails with:

```
NotImplementedError: Reverse-mode differentiation rule for 'custom_jvp_call_jaxpr' not implemented
```

(Sorry for the same messy gist, lines 340-369.) In my application, products with tangent or cotangent vectors happen implicitly through a linear solve, hence a custom primitive is necessary, which I tried to emulate with …
Force-pushed from 906b676 to 265740a
@IvanYashchuk thanks for sharing that excellent code! You've identified a constraint I was not aware of: we can't currently use … We could: …
I'll look into the second option. @dougalm might find this interesting!
I added a transpose rule, which assumes the jaxpr is linear. The code in your gist runs after you define a transpose rule for the … primitive. I'm not sure if it'll come up much, but it was a nice thing to think about.
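As context for what transposing a linear jaxpr involves: the public jax.linear_transpose API, available in later JAX releases, exposes this kind of machinery directly. A minimal sketch, not code from this PR:

```python
import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(3, 2)
f = lambda x: A @ x  # linear in x, so its jaxpr can be transposed

# Build the transpose of f by transposing its jaxpr; the example
# argument only supplies the input shape and dtype.
f_t = jax.linear_transpose(f, jnp.zeros(2))

y_bar = jnp.ones(3)
x_bar, = f_t(y_bar)
print(jnp.allclose(x_bar, A.T @ y_bar))  # True
```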
Force-pushed from bcbf144 to 303de1d
Force-pushed from 303de1d to 7e480fa
temporarily revert parts of #2026 pending bug fix
fixes #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, #2345, #2346
See the design doc and user tutorial for details.
TODO:
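For reference, the custom-VJP interface that eventually shipped looks like this. A minimal sketch using the released jax.custom_vjp API, which differs in detail from the in-flight jax.api.custom_vjp used earlier in the thread:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Return the primal output plus residuals saved for the backward pass.
    return f(x), jnp.cos(x)

def f_bwd(cos_x, g):
    # Return a tuple of cotangents, one per primal input.
    return (g * cos_x,)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(1.0))  # cos(1.0), computed via the custom rule
```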