custom_transforms vjp rule clobbered under vmap #1249

Closed
mattjj opened this issue Aug 26, 2019 · 2 comments
Labels
bug Something isn't working

@mattjj
Collaborator

mattjj commented Aug 26, 2019

(Thanks to @John-Jumper for pointing this out.)

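# Requires a 2019-era JAX release; jax.custom_transforms has since been removed.
import jax
import jax.numpy as np

@jax.custom_transforms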
def f(x):
  return 2. * x
jax.defvjp_all(f, lambda x: (2. * x, lambda _: (3.,)))

print('grad', jax.grad(f)(1.))
print('vmap-of-grad', jax.vmap(jax.grad(f))(np.ones(4)))
print('grad-of-sum-vmap', jax.grad(lambda x: jax.vmap(f)(x).sum())(np.ones(4)))
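Output:

grad 3.0
vmap-of-grad [3. 3. 3. 3.]
grad-of-sum-vmap [2. 2. 2. 2.]

The first two results use the custom VJP rule (3.), but the last one should also be [3. 3. 3. 3.]. When f is differentiated through vmap, the custom rule is dropped and the ordinary gradient of 2. * x is used instead.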
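For comparison, here is a sketch of the same reproduction against current JAX. It assumes `jax.custom_vjp`, the replacement for the removed `jax.custom_transforms` / `jax.defvjp_all` pair. One deliberate change: the backward rule scales the incoming cotangent `g` instead of ignoring it, because `custom_vjp` expects a cotangent shaped like the input. Under `jax.grad` of a scalar function, `g` is 1, so the rule still returns 3.

```python
import jax
import jax.numpy as jnp


# Same function as in the report: f(x) = 2x, with a custom VJP that yields 3.
@jax.custom_vjp
def f(x):
    return 2.0 * x


def f_fwd(x):
    # Forward pass: primal output plus residuals (none are needed here).
    return f(x), None


def f_bwd(_, g):
    # Backward pass: scale the incoming cotangent. For jax.grad of a scalar,
    # g == 1.0, so this matches the original `lambda _: (3.,)`.
    return (3.0 * g,)


f.defvjp(f_fwd, f_bwd)

print('grad', jax.grad(f)(1.))
print('vmap-of-grad', jax.vmap(jax.grad(f))(jnp.ones(4)))
print('grad-of-sum-vmap', jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.ones(4)))
```

With the custom rule respected under `vmap`, all three lines should report gradients of 3.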
mattjj added the bug (Something isn't working) label Aug 26, 2019
@jekbradbury
Contributor

When you write a custom_transforms primitive, you're basically saying "here is an implementation of this primitive that's valid for every transformation except the JVP interpreter; for that one, use this other implementation." When you vmap that primitive, vmap traces through the first implementation, so the result no longer contains the primitive. When you then run the JVP interpreter on that result (e.g. by calling grad), there is nothing left to override, so JVP differentiates the implementation it was given.
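The mechanism can be illustrated with a deliberately tiny toy model in plain Python. Nothing here is JAX's real implementation: `Plain`, `Custom`, `derivative`, and `vmap_lower` are hypothetical names, programs are just lists of scalar ops, and vmap is modeled only by its effect on the program (each custom wrapper is replaced by its ordinary body). The batch itself is handled by a list comprehension.

```python
from dataclasses import dataclass
from typing import Callable, List, Union


@dataclass
class Plain:
    """Toy ordinary primitive: its derivative follows from its definition."""
    fn: Callable[[float], float]
    dfn: Callable[[float], float]


@dataclass
class Custom:
    """Toy custom_transforms-style wrapper: `body` is the ordinary
    implementation; `custom_dfn` is a derivative override that only the
    toy JVP interpreter consults, and only while the wrapper is present."""
    body: Plain
    custom_dfn: Callable[[float], float]


Program = List[Union[Plain, Custom]]


def derivative(prog: Program, x: float) -> float:
    """Toy JVP interpreter: chain rule over the ops, honoring the override
    on any Custom wrapper that is still in the program."""
    d = 1.0
    for op in prog:
        if isinstance(op, Custom):
            d *= op.custom_dfn(x)
            x = op.body.fn(x)
        else:
            d *= op.dfn(x)
            x = op.fn(x)
    return d


def vmap_lower(prog: Program) -> Program:
    """Toy stand-in for tracing under vmap: batching only understands Plain
    ops, so each Custom wrapper is replaced by its body and the override is lost."""
    return [op.body if isinstance(op, Custom) else op for op in prog]


# f(x) = 2x with a (deliberately different) derivative override of 3.
f: Program = [Custom(body=Plain(fn=lambda x: 2.0 * x, dfn=lambda x: 2.0),
                     custom_dfn=lambda x: 3.0)]
xs = [1.0, 1.0, 1.0, 1.0]

vmap_of_grad = [derivative(f, x) for x in xs]              # override still visible
grad_of_vmap = [derivative(vmap_lower(f), x) for x in xs]  # override already gone
print(vmap_of_grad)  # [3.0, 3.0, 3.0, 3.0]
print(grad_of_vmap)  # [2.0, 2.0, 2.0, 2.0]
```

The two lists mirror the vmap-of-grad and grad-of-sum-vmap lines of the original report.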

This is arguably expected (even correct) behavior when a custom_transforms override preserves the function's semantics, for example when it only improves numerics or performance. But we're seeing custom_transforms used in ways that go further than that, such as supplying a JVP rule for a function JAX can't differentiate, or, as in the example above, a rule that disagrees with the implementation.

One fix would be to coerce jvp(vmap(f)) into vmap(jvp(f)), i.e., to make sure the overridden interpreter always ends up as the innermost trace. But our transformations don't commute in general, so even that approach is limited.
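The proposed reordering relies on the two orders agreeing whenever no custom rule is involved. A small check in current JAX, assuming an ordinary elementwise function (the `f` below is illustrative, not from the issue), compares `jax.jvp` of a vmapped function with `jax.vmap` of a per-example `jax.jvp`:

```python
import jax
import jax.numpy as jnp


def f(x):
    # An ordinary elementwise function with no custom rules.
    return 2.0 * jnp.sin(x)


xs = jnp.linspace(0.0, 1.0, 4)
ts = jnp.ones_like(xs)

# jvp(vmap(f)): JVP is the outer transformation; batching happens inside it.
_, tangent_outer = jax.jvp(jax.vmap(f), (xs,), (ts,))

# vmap(jvp(f)): JVP is the innermost transformation, applied per example.
_, tangent_inner = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))(xs, ts)

print(bool(jnp.allclose(tangent_outer, tangent_inner)))  # True
```

The orders are not interchangeable in general, though: `jax.grad` requires a scalar output, so `jax.grad(jax.vmap(f))` fails for this vector-valued `jax.vmap(f)`, while `jax.vmap(jax.grad(f))` works.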

For those running into this problem, the simplest workaround is to manually move the overridden transformation inside any other: if your code has grad(vmap(f)), try moving the vmap outside the grad, as in vmap(grad(f)). A more general, but more complex, user-level fix is to add a custom_transforms rule that defines vmap(f) in terms of f (perhaps by making f rank-polymorphic).
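The rank-polymorphic route can be sketched as follows, again assuming current JAX's `jax.custom_vjp` in place of `custom_transforms` (the workaround itself was written for the 2019 API). Because `f` uses only elementwise operations, applying it to the whole batch already computes what `jax.vmap(f)` would, so the gradient can be taken with no `vmap` in the picture.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    # Only elementwise ops, so f accepts scalars and arrays alike.
    return 2.0 * x


def f_fwd(x):
    return f(x), None  # no residuals needed


def f_bwd(_, g):
    return (3.0 * g,)  # same deliberately mismatched rule as in the issue


f.defvjp(f_fwd, f_bwd)

xs = jnp.ones(4)

# Since f is rank-polymorphic, "vmap(f)" can simply be f applied to the batch.
batched_values = f(xs)
vmapped_values = jax.vmap(f)(xs)

# Differentiating the whole-batch call involves no vmap at all, so the
# clobbering described in this issue cannot arise along this path.
grads = jax.grad(lambda x: f(x).sum())(xs)
```

Here `grads` should come out as the custom rule's 3 for every element.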

mattjj self-assigned this Jan 7, 2020
@mattjj
Collaborator Author

mattjj commented Mar 22, 2020

Whew, #2026 fixed this!

mattjj closed this as completed Mar 22, 2020