import jax
import jax.numpy as jnp

@jax.custom_vjp
def custom_identity(a):
    return a

def custom_identity_fwd(a):
    y = custom_identity(a)
    return y, (a,)

def custom_identity_bwd(args, g):
    # custom_identity_bwd must be implemented in terms of custom_identity,
    # i.e. `return (g,)` does not trigger the error even though it is
    # semantically identical
    return (custom_identity(g),)

custom_identity.defvjp(custom_identity_fwd, custom_identity_bwd)

def custom_identity_in_scan(a):
    def scan_fun(carry, a):
        return carry, custom_identity(a)

    # custom_identity must be called inside a scan
    _, out = jax.lax.scan(
        f=scan_fun,
        init=a,
        xs=jnp.stack([a] * 3, axis=0)
    )
    return out[0]

print(jax.grad(custom_identity)(1.0))  # prints 1.0
print(jax.grad(custom_identity_in_scan)(1.0))  # <-- error here
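For comparison, here is a sketch of the variant that the comment in `custom_identity_bwd` says does not trigger the error: the backward rule returns the cotangent `(g,)` directly instead of calling the custom_vjp function again. The names `passthrough_identity` and `passthrough_identity_in_scan` are hypothetical, introduced only for this comparison; according to the report, this version should not hit the assertion even on the affected JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical name: same identity function as in the reproduction.
@jax.custom_vjp
def passthrough_identity(a):
    return a

def passthrough_identity_fwd(a):
    # Primal output plus residuals (unused by the backward rule).
    return a, (a,)

def passthrough_identity_bwd(res, g):
    # Return the cotangent unchanged rather than calling
    # passthrough_identity(g), which is the difference from the repro.
    return (g,)

passthrough_identity.defvjp(passthrough_identity_fwd, passthrough_identity_bwd)

def passthrough_identity_in_scan(a):
    def scan_fun(carry, x):
        # The custom_vjp function is still called inside the scan body.
        return carry, passthrough_identity(x)

    _, out = jax.lax.scan(scan_fun, a, jnp.stack([a] * 3, axis=0))
    return out[0]

print(jax.grad(passthrough_identity_in_scan)(1.0))  # prints 1.0
```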
This code works in the incoming core revision #3370, so it's tempting just to wait for that to land rather than fixing it on master. How blocking is this for you?
I found I can work around this by adding another layer of indirection so that custom_identity_bwd is not implemented in terms of custom_identity, but instead both are implemented in terms of a separate _custom_identity_impl.
I expect this might not work with higher-order derivatives, but I don't need those right now, so it is not blocking.
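A minimal sketch of the workaround described in the comment above, assuming the shared helper is a plain (non-custom_vjp) function named `_custom_identity_impl`, as in the comment, and assuming the forward rule also calls it. The commenter's actual code is not shown in the thread, so treat this as one possible shape of the indirection rather than their implementation.

```python
import jax
import jax.numpy as jnp

def _custom_identity_impl(a):
    # Plain implementation shared by the primal function and the rules.
    return a

@jax.custom_vjp
def custom_identity(a):
    return _custom_identity_impl(a)

def custom_identity_fwd(a):
    # Assumption: the forward rule also goes through the shared helper.
    return _custom_identity_impl(a), (a,)

def custom_identity_bwd(res, g):
    # The backward rule no longer calls custom_identity itself,
    # only the shared helper.
    return (_custom_identity_impl(g),)

custom_identity.defvjp(custom_identity_fwd, custom_identity_bwd)

def custom_identity_in_scan(a):
    def scan_fun(carry, x):
        return carry, custom_identity(x)

    _, out = jax.lax.scan(scan_fun, a, jnp.stack([a] * 3, axis=0))
    return out[0]

print(jax.grad(custom_identity_in_scan)(1.0))  # prints 1.0
```

Because `custom_identity_bwd` only calls an ordinary function, differentiating through the backward rule never re-enters the custom_vjp machinery, which is the point of the extra layer of indirection.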
The code snippet in this issue triggers an assertion in process_custom_vjp_call (here: https://github.com/google/jax/blob/master/jax/interpreters/partial_eval.py#L315). The comments there imply that this should not be possible.