Impossible master trace type #3397

Closed
mdenil opened this issue Jun 10, 2020 · 2 comments · Fixed by #4038
Labels
bug Something isn't working

Comments

mdenil commented Jun 10, 2020

The following code snippet triggers an assertion in process_custom_vjp_call (here: https://github.com/google/jax/blob/master/jax/interpreters/partial_eval.py#L315). The comments at that location imply that this case should not be possible.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def custom_identity(a):
  return a
 
def custom_identity_fwd(a):
  y = custom_identity(a)
  return y, (a,)
 
def custom_identity_bwd(args, g):
  # To trigger the error, custom_identity_bwd must be implemented in terms of
  # custom_identity; e.g. `return (g,)` does not trigger the error even though
  # it is semantically identical.
  return (custom_identity(g),)
 
custom_identity.defvjp(custom_identity_fwd, custom_identity_bwd)
 
 
def custom_identity_in_scan(a):
  def scan_fun(carry, a):
    return carry, custom_identity(a)
 
  # To trigger the error, custom_identity must be called inside a scan
  _, out = jax.lax.scan(
      f=scan_fun,
      init=a,
      xs=jnp.stack([a] * 3, axis=0)
  )
 
  return out[0]
 
print(jax.grad(custom_identity)(1.0))  # prints 1.0
print(jax.grad(custom_identity_in_scan)(1.0))  # <-- error here
mattjj self-assigned this Jun 10, 2020
mattjj added the bug label (Something isn't working) Jun 10, 2020
mattjj (Collaborator) commented Jun 10, 2020

This code works in the incoming core revision #3370, so it's tempting just to wait for that to land rather than fixing it on master. How blocking is this for you?

mdenil (Author) commented Jun 11, 2020

Thanks for the quick response.

I found I can work around this by adding another layer of indirection: instead of implementing custom_identity_bwd in terms of custom_identity, I implement both in terms of a separate _custom_identity_impl.

I expect this might not work with higher-order derivatives, but I don't need those right now, so it is not blocking.
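The workaround described in this comment might look roughly like the sketch below. The thread does not show the author's actual code, so this is a reconstruction under assumptions: only the name `_custom_identity_impl` comes from the comment, and it is assumed here to be a plain function that returns its input unchanged. The primal function and both VJP rules call that helper, so `custom_identity_bwd` never calls the `custom_vjp`-decorated function. The last lines check first-order reverse mode through `jax.lax.scan` and, for this one simple case only, reverse-over-reverse differentiation (the second derivative of an identity map is 0).

```python
import jax
import jax.numpy as jnp


def _custom_identity_impl(a):
    # Assumed plain (non-custom_vjp) helper shared by the primal and the VJP rules.
    return a


@jax.custom_vjp
def custom_identity(a):
    return _custom_identity_impl(a)


def custom_identity_fwd(a):
    # Call the helper directly instead of the custom_vjp-decorated function.
    y = _custom_identity_impl(a)
    return y, (a,)


def custom_identity_bwd(res, g):
    # The backward rule also uses the helper, not custom_identity itself.
    return (_custom_identity_impl(g),)


custom_identity.defvjp(custom_identity_fwd, custom_identity_bwd)


def custom_identity_in_scan(a):
    def scan_fun(carry, x):
        return carry, custom_identity(x)

    _, out = jax.lax.scan(scan_fun, a, jnp.stack([a] * 3, axis=0))
    return out[0]


# First-order reverse mode through the scan.
first = jax.grad(custom_identity_in_scan)(1.0)
# Reverse-over-reverse on the custom_vjp function itself.
second = jax.grad(jax.grad(custom_identity))(1.0)
print(first, second)
```

Passing only one quick case like this does not show that higher-order derivatives are correct in general; it just exercises the structure of the workaround.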

2 participants