Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add custom_jvp / vjp, delete custom_transforms #2026

Merged
merged 1 commit into from
Mar 22, 2020
Merged

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Jan 19, 2020

fixes #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, #2345, #2346

See the design doc and user tutorial for details.

TODO:

  • improve error messages
  • write a PR message / design doc
  • add docstrings and reference docs
  • fix user code (Google code, NumPyro, jaxnet, anything else we think of...)

@mattjj mattjj requested a review from dougalm January 19, 2020 06:05
@mattjj mattjj force-pushed the custom-transforms3 branch 2 times, most recently from e67d117 to 3c8e501 Compare January 19, 2020 18:43
@shoyer
Copy link
Collaborator

shoyer commented Jan 20, 2020

Very cool!

I'm going to see how it looks to use this for custom_root.

How hard would it be to support defining a transpose rule along with the jvp rule, i.e., custom_jvp_and_transpose? I think this might be done by doing something like calling into custom_linearized_p.bind for custom_jvp_call as well as custom_vjp_call inside JVPTrace.process_custom_call. This would potentially be enough to let us use this machinery for custom_linear_solve (and other linear functions that need custom derivative rules, like linear_odeint).

@shoyer
Copy link
Collaborator

shoyer commented Jan 20, 2020

See #2029 for my attempt to use this in custom_root.

Generally this went well, but I think it turned up a bug for higher order differentiation of custom_jvp. Here's my simplified test case:

import jax

def fwd(x):
    return 2 * x, None

def jvp(aux, g):
    return 3 * g

f = jax.custom_jvp(fwd, jvp)

def f2(x):
    # should match f
    y, _ = jax.jvp(f, (x,), (x,))
    return y

x = 1.0
print('f:', f(x))
print('f jvp:', jax.jvp(f, (x,), (x,)))
print('f2:', f2(x))
print('f2 jvp:', jax.jvp(f2, (x,), (x,)))

Outputs:

f: 2.0
f jvp: (2.0, 3.0)
f2: 2.0
f2 jvp: (DeviceArray(2., dtype=float32), DeviceArray(2., dtype=float32))

Notice that custom JVP rule only effects direct differentiation of f. The alternate version f2 that should be equivalent uses the original derivative.

My guess is that this comes down to how the custom JVP function is called directly inside JVPTrace.process_custom_call rather than calling primitive.bind -- so the custom derivative rule no longer exists once you've gone through a JVP pass.

@mattjj
Copy link
Collaborator Author

mattjj commented Jan 20, 2020

Thanks for trying this out, @shoyer!

Re: higher-order differentiation, that behavior was actually intentional. But I like the point you're making, namely that f2 should match f, and I think you're right. I can adapt the implementation to reflect those semantics.

@shoyer
Copy link
Collaborator

shoyer commented Jan 20, 2020

Re: higher-order differentiation, that behavior was actually intentional. But I like the point you're making, namely that f2 should match f, and I think you're right. I can adapt the implementation to reflect those semantics.

OK, great! I'm hoping that will also fix higher order differentiation for my differentiable binary search over in #2029.

@mattjj
Copy link
Collaborator Author

mattjj commented Jan 20, 2020

Btw, I'm not confident in the commit I pushed. Gotta iterate more tomorrow!

jax/api.py Outdated Show resolved Hide resolved
@IvanYashchuk
Copy link

Hi!
There are two changes needed to make custom_vjp function to work both with jax.vjp and jax.jvp:

  1. Replace custom_lin_p with direct call to rev
-  tangents_out = custom_lin_p.bind(
-      *it.chain(res, tangents_in), num_res=res_tree.num_leaves, rev=rev,
-      avals_out=avals_out)
+  tangents_out = rev.call_wrapped(*(res + tangents_in))
  1. Define transpose rule for custom_vjp_call_jaxpr_p . How difficult is this transpose? I tried to implement it myself but couldn't make it yet.

Alternatively, defining the implementation for custom_lin_p makes it possible to use both jax.vjp and jax.jvp on custom_vjp functions.
IvanYashchuk@97625bc
However, as custom_lin_p does not have batching, jvp, xla rules defined jax.jacfwd etc. do not work.

Script I used for experiments: https://gist.github.com/IvanYashchuk/bd1a1aaddf952d66569a55b8fba48e67

@shoyer
Copy link
Collaborator

shoyer commented Mar 5, 2020

@IvanYashchuk could you give a specific example of why you need these changes? What doesn't work with the current version of this PR?

@IvanYashchuk
Copy link

Let's consider only custom_vjp here. In the current version of this PR jax.jvp doesn't work for functions wrapped with custom_vjp, because of https://github.com/google/jax/blob/b11c0b2a53b1ca790b5d69a55cae1926774524c9/jax/custom_derivatives.py#L317-L320
custom_vjp requires fwd and rev functions, which are basically jvp and transpose rules (am I correct?). So the ingredients are there.

import jax
import jax.numpy as jnp
from jax.core import Primitive

@jax.api.custom_vjp
def external_cos(x):
    return jnp.cos(x)

@jax.api.custom_vjp
def external_sin(x):
    return jnp.sin(x)

def fwd_external_cos(x):
    out_primal = external_cos(x)
    out_tangent = -external_sin(x)
    return out_primal, out_tangent

def rev_external_cos(fwd_out, g):
    return (g * fwd_out,)

def fwd_external_sin(x):
    out_primal = external_sin(x)
    out_tangent = external_cos(x)
    return out_primal, out_tangent

def rev_external_sin(fwd_out, g):
    return (g * fwd_out,)

external_cos.defvjp(fwd_external_cos, rev_external_cos)
external_sin.defvjp(fwd_external_sin, rev_external_sin)

# vjp, nested grad and jacrev are OK as expected
x = 1.0
res, vjp_cos = jax.vjp(external_cos, x)
vjp_res = vjp_cos(1.0)  # -sin(x) * 1.0

jax.grad(jax.grad(jax.grad(jax.grad(external_cos))))(1.0)  # sin(x)

# This raises an error on current version of PR!
jax.jvp(external_cos, (jnp.ones(1),), (0.5*jnp.ones(1),))

Applying this small change IvanYashchuk/jax@97625bc makes it possible to use jax.jvp on wrapped custom_vjp functions.

# Now this works with IvanYashchuk/jax@97625bc
res, jvp_res = jax.jvp(external_cos, (jnp.ones(1),), (0.5*jnp.ones(1),))
# or
_, jvp_fun = jax.linearize(external_cos, x)
jvp_res2 = jvp_fun(0.5)
assert jvp_res == jvp_res2

Good, but jax.jacfwd still does not work because Batching rule for custom_lin Primitive is not implemented. JVP over jax.linearize also does not work because JVP rule for custom_lin is not implemented. Let's try to fix it by not using custom_lin:
In custom_derivatives.py, in _custom_vjp_call_jvp function body replace custom_lin_p with direct call to rev

-  tangents_out = custom_lin_p.bind(
-      *it.chain(res, tangents_in), num_res=res_tree.num_leaves, rev=rev,
-      avals_out=avals_out)
+  tangents_out = rev.call_wrapped(*(res + tangents_in))

With this change jacfwd works!

jax.jacfwd(external_cos)(jnp.ones((5, 1)))

Also does work JVP with linearize function and grad

_, jvp_fun = jax.linearize(external_cos, x)
jax.jvp(jvp_fun , (1.0,), (2.0,))  # (-sin(x)*primal, -sin(x)*tangent)
jax.jvp(jax.grad(external_cos), (1.0,), (2.0,)) # (-sin(primal), -cos(primal)*tangent)

Hessian also works

ff = lambda x: jnp.sum((external_cos)(x))
jax.hessian(ff)(jnp.ones((5, 1)))

This is all great, seems like everything works now!

However, _custom_vjp_call_jaxpr path gets broken. jvp and jacfwd work but vjp gets broken because now Transpose rule for custom_vjp_call_jaxpr Primitive is required, which is not implemented.
JVP over jax.linearize doesn't work: error in custom_derivatives.py in _flatten_rev
ValueError: Too many leaves for PyTreeDef; expected 1.
This _custom_vjp_call_jaxpr path is taken for example when wrapped custom_vjp function binds to custom Primitive, like in this gist.

@mattjj
Copy link
Collaborator Author

mattjj commented Mar 6, 2020

@IvanYashchuk thanks for digging in!

However, this won't work in general:

tangents_out = rev.call_wrapped(*(res + tangents_in))

It's off by a transpose. You might not have noticed if you only tried it with elementwise functions, like sine and cosine, essentially because the transpose a diagonal operator is the operator itself. But for general functions (e.g. involving dot or conv) the shapes don't even work out for rev to be applied to tangents_in.

One thing we can do better is automatically derive a forward-mode rule from a custom VJP definition by automatic transposition, exactly analogous to how this PR already supports reverse mode over a custom JVP rule (see e.g. CustomJVPTest.test_basic for an example of grad over a custom_jvp function). I didn't add that, and instead raised an error, because I wasn't sure anyone would care about it. Is this a real use case for you? If so, we can open a feature request for it, though I'm not likely to add it in this PR.

(By the way, another problem with the diff removing custom_lin_p.bind is that rev is called on the forward pass, which would break some of our highest-priority use cases for this code, e.g. being able to insert a debugger trace on the backward pass. custom_lin exists for a reason.)

@IvanYashchuk
Copy link

IvanYashchuk commented Mar 6, 2020

@mattjj thank you for the explanations! I'm just scratching the surface here and trying to understand the inner workings.

I didn't notice this test case that grad is working for custom_jvp and apparently vjp over grad also works. When I was testing it myself I used custom primitives for all operations and I hit

NotImplementedError: Reverse-mode differentiation rule for 'custom_jvp_call_jaxpr' not implemented

(Sorry for the same messy gist, lines 340-369)

In my application, products with tangent or cotangent vectors happen implicitly through linear solve hence custom primitive is necessary, which I tried to emulate with external product in my experiments.
For my use case I think I almost got what I need by defining all rules for primitives directly without using helper functions defvjp, defjvp and be able to do forward-over-reverse and higher-order.

@mattjj mattjj force-pushed the custom-transforms3 branch 2 times, most recently from 906b676 to 265740a Compare March 11, 2020 22:47
@mattjj
Copy link
Collaborator Author

mattjj commented Mar 11, 2020

@IvanYashchuk thanks for sharing that excellent code!

You've identified a constraint I was not aware of: we can't currently use custom_jvp on a linear function. And external_product is linear in each of its arguments separately; moreover, it's used in a computation that must be linear (namely the expression for out_tangent in both jvp_external_sin and jvp_external_cos). (The error message being raised is a bit misleading: the error is actually that there's no transpose rule for custom_jvp_call_jaxpr_p.)

We could:

  1. just say you can't use custom_jvp on a linear function and be happy with that
  2. attempt to transpose custom_jvp_call_jaxpr like we do regular calls, which is only correct (and will only succeed) if the jaxpr being called is in fact linear, though it'd be nicer if we could check up front that the jaxpr is linear

I'll look into the second option.

@dougalm might find this interesting!

@mattjj
Copy link
Collaborator Author

mattjj commented Mar 12, 2020

I added a transpose rule, which assumes the jaxpr is linear. The code in your gist runs after you define a transpose rule for the external_product_p primitive (which can be the same as the transpose rule for mul_p defined in lax.py).

I'm not sure if it'll come up much, but it was a nice thing to think about.

@mattjj mattjj force-pushed the custom-transforms3 branch 2 times, most recently from bcbf144 to 303de1d Compare March 22, 2020 05:04
@mattjj mattjj force-pushed the custom-transforms3 branch from 303de1d to 7e480fa Compare March 22, 2020 05:08
@mattjj mattjj merged commit 069cb3e into master Mar 22, 2020
mattjj added a commit that referenced this pull request Mar 26, 2020
mattjj added a commit that referenced this pull request Mar 26, 2020
temporarily revert parts of #2026 pending bug fix
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Apr 2, 2020
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Apr 13, 2020
@jakevdp jakevdp deleted the custom-transforms3 branch October 6, 2021 19:35
@froystig froystig added the JEP JAX enhancement proposal label Aug 9, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes JEP JAX enhancement proposal
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Easy api for custom primitives and vjps
8 participants