Minimal (non-)working example, calling jax.experimental.host_callback.id_tap inside a jit-compiled function:
from jax import jit, grad, numpy as jnp
from jax.experimental.host_callback import id_tap
from jax import custom_vjp

@jit
def main(x):
    acc = []
    id_tap(lambda a, t: acc.append(jnp.sin(a[0])), [x])
    return acc[0]

res = grad(main)(jnp.pi)
res
When using custom_vjp:
from jax import jit, grad, numpy as jnp
from jax.experimental.host_callback import id_tap
from jax import custom_vjp

@jit
def main(x):

    @custom_vjp
    def f(x):
        acc = []
        id_tap(lambda a, t: acc.append(jnp.sin(a[0])), [x])
        return acc[0]

    def f_fwd(x):
        return f(x), (x,)

    def f_bwd(res, g):
        acc = []
        y = res
        id_tap(lambda a, t: acc.append(tuple([jnp.cos(a[0])])), [y[0]])
        return acc[0]

    f.defvjp(f_fwd, f_bwd)
    return f(x)

res = grad(main)(jnp.pi)
res
Running this raises:

---> 13 return acc[0]
     14
     15 def f_fwd(x):

IndexError: list index out of range
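A plausible reading of the IndexError (an interpretation, not something stated in the report): under jit, the Python body of main runs once during tracing, where x is an abstract tracer, while the id_tap callback only fires later when the compiled computation actually executes. So acc is still empty when return acc[0] runs. The sketch below illustrates the trace-time behaviour with a plain Python side effect instead of host_callback; the names traced_sin and trace_log are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical log used only to observe when the Python body of the function runs.
trace_log = []


@jax.jit
def traced_sin(x):
    # This Python statement runs while jit traces the function, not on every call.
    # At this point x is an abstract tracer, not a concrete number.
    trace_log.append(type(x).__name__)
    return jnp.sin(x)


traced_sin(1.0)
traced_sin(2.0)  # same shape and dtype, so the cached compilation is reused
print(len(trace_log))  # 1
```

Because the body only ever sees tracers, anything a host callback would append to a Python list cannot be read back at trace time.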
Without jit, no error arises and we get the expected result: DeviceArray(-1., dtype=float32).
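If the goal is to bring a host-computed value back into a jit-compiled computation, one alternative is jax.pure_callback. This is a swapped-in technique, not something tried in the report. Because it is given the result shape and dtype up front, it can be staged out under jit. The sketch below mirrors the report's custom_vjp structure and assumes a JAX version that provides jax.pure_callback; the helpers _host_sin and _host_cos are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical host-side helpers: they receive concrete NumPy arrays at run time.
def _host_sin(x):
    return np.asarray(np.sin(x), dtype=x.dtype)


def _host_cos(x):
    return np.asarray(np.cos(x), dtype=x.dtype)


@jax.custom_vjp
def f(x):
    # Declare the output shape/dtype so the callback can be staged out under jit.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_sin, spec, x)


def f_fwd(x):
    # Save x as the residual for the backward pass.
    return f(x), (x,)


def f_bwd(res, g):
    (x,) = res
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    cos_x = jax.pure_callback(_host_cos, spec, x)
    # Standard VJP: scale the local derivative cos(x) by the incoming cotangent g.
    return (g * cos_x,)


f.defvjp(f_fwd, f_bwd)


@jax.jit
def main(x):
    return f(x)


res = jax.grad(main)(jnp.float32(np.pi))
print(res)
```

Unlike the report's f_bwd, this version multiplies by the cotangent g; for a scalar grad, g is 1, so both give cos(pi), which matches the expected -1.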