Error when computing the gradient of a jitted function that uses id_tap #9172

Open
3 tasks done
antalszava opened this issue Jan 11, 2022 · 0 comments
Labels
bug Something isn't working

Comments

@antalszava
Contributor

Please:

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:

Minimal (non-)working example:

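```python
from jax import jit, grad, numpy as jnp
from jax.experimental.host_callback import id_tap

@jit
def main(x):
    acc = []
    # Under jit the tap callback fires only when the compiled code runs,
    # not while main is being traced, so acc is still empty below.
    id_tap(lambda a, t: acc.append(jnp.sin(a[0])), [x])
    return acc[0]  # IndexError: list index out of range

res = grad(main)(jnp.pi)
res
```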
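The failure can be illustrated without host_callback. The following is a minimal sketch, not part of the original report: it assumes a newer JAX release that provides jax.debug.callback, used here as a stand-in for id_tap, and the names record, host_values, and trace_time_sizes are hypothetical. It compares how many callback results exist while jit traces the function with how many exist after the compiled call has run.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.debug.callback stands in for host_callback.id_tap here.
host_values = []       # filled by the host callback at run time
trace_time_sizes = []  # what the Python body sees while jit is tracing

def record(a):
    # Runs on the host only when the compiled computation executes.
    host_values.append(a)

@jax.jit
def main(x):
    jax.debug.callback(record, x)
    # This line runs during tracing, before any callback has fired.
    trace_time_sizes.append(len(host_values))
    return jnp.sin(x)

out = main(jnp.pi)
jax.effects_barrier()  # wait for pending callbacks to finish

print(trace_time_sizes)  # [0]: the list was empty while tracing
print(len(host_values))  # 1: the callback fired once at run time
```

Indexing host_values inside main would therefore fail during tracing, which matches the IndexError reported here.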

When using custom_vjp:

```python
from jax import jit, grad, numpy as jnp
from jax.experimental.host_callback import id_tap

from jax import custom_vjp

@jit
def main(x):

    @custom_vjp
    def f(x):
        acc = []
        id_tap(lambda a, t: acc.append(jnp.sin(a[0])), [x])
        return acc[0]

    def f_fwd(x):
        return f(x), (x,)

    def f_bwd(res, g):
        acc = []
        y = res

        id_tap(lambda a, t: acc.append(tuple([jnp.cos(a[0])])), [y[0]])
        return acc[0]

    f.defvjp(f_fwd, f_bwd)
    return f(x)

res = grad(main)(jnp.pi)
res
```
  • If applicable, include full error messages/tracebacks.
```
---> 13         return acc[0]
     14 
     15     def f_fwd(x):

IndexError: list index out of range
```

Without jitting, no error is raised and we get the expected result: DeviceArray(-1., dtype=float32).
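One possible workaround, sketched here as an assumption rather than anything proposed in this issue: id_tap discards its callback's return value (it returns its argument unchanged), so host-computed results cannot flow back into a traced computation through it. jax.pure_callback, available in newer JAX releases, does return a value of a declared shape and dtype, and pairing it with custom_vjp supplies the gradient that the callback itself lacks. The helper names _like, host_sin, and host_cos are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: jax.pure_callback is used instead of host_callback.id_tap.
def _like(x):
    # Declared result spec for pure_callback: same shape and dtype as x.
    return jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))

def host_sin(a):
    # Runs on the host; the returned array is fed back into the JAX trace.
    return np.asarray(np.sin(a), dtype=a.dtype)

def host_cos(a):
    return np.asarray(np.cos(a), dtype=a.dtype)

@jax.custom_vjp
def f(x):
    return jax.pure_callback(host_sin, _like(x), x)

def f_fwd(x):
    # Save x as the residual for the backward pass.
    return f(x), x

def f_bwd(x, g):
    # Chain rule: d/dx sin(x) = cos(x), scaled by the incoming cotangent g.
    return (g * jax.pure_callback(host_cos, _like(x), x),)

f.defvjp(f_fwd, f_bwd)

res = jax.grad(jax.jit(f))(jnp.pi)
print(res)  # approximately -1.0
```

Because pure_callback has no differentiation rule of its own, the custom_vjp wrapper is what makes jax.grad work here; the callback only supplies values.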
