jax.debug.breakpoint crashes in a hard-to-describe way. #16732

Open
patrick-kidger opened this issue Jul 14, 2023 · 2 comments
Labels
bug Something isn't working


@patrick-kidger
Collaborator

patrick-kidger commented Jul 14, 2023

Description

This:

import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

jax.eval_shape(fn)
fn()

produces:

jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.

I think what's happening is something like the following. The eval_shape call means that x0 is a tracer. Then jax.debug.breakpoint grabs x0 from the stack frame above it. Then jax.jit saves it as a constant of its jaxpr. This means that JAX itself leaks the tracer.
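The first step of that reasoning (x0 being a tracer during eval_shape, and visible from a frame further down the stack) can be checked directly. The sketch below uses a hypothetical helper, `tracer_locals_in_callers`, which is not a JAX API: it walks up the call stack with Python's `inspect` module, roughly the way a debugger collects variables, and reports which locals are JAX tracers. It assumes `jax.core.Tracer` is importable, as it is on the 0.4.x releases discussed here; the exact tracer class name may differ between JAX versions.

```python
import inspect

import jax
import jax.numpy as jnp


def tracer_locals_in_callers(num_frames=1):
    """Hypothetical helper (not part of JAX): collect locals from the frames
    above this one, like a debugger would, and report which are tracers."""
    found = {}
    frame = inspect.currentframe().f_back  # start at the immediate caller
    for _ in range(num_frames):
        if frame is None:
            break
        for name, value in frame.f_locals.items():
            if isinstance(value, jax.core.Tracer):
                found[name] = type(value).__name__
        frame = frame.f_back
    return found


reports = []


def fn():
    x0 = jnp.zeros(2)
    reports.append(tracer_locals_in_callers())
    return x0


jax.eval_shape(fn)  # fn is traced abstractly: x0 is a tracer
fn()                # fn runs eagerly: x0 is a concrete array

print(reports[0])  # x0 is listed, together with its tracer class name
print(reports[1])  # {}
```

Anything that stashes such a frame-local tracer and reuses it after eval_shape has returned is holding a reference to a finished trace, which is the situation the UnexpectedTracerError describes.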

Tagging @sharadmv.

(Also, it doesn't suffice to just stop grabbing stack frames once you leave the JIT'd region. Consider this variant:

import jax

@jax.jit
def brk():
    jax.debug.breakpoint()

def fn():
    x0 = jax.numpy.zeros(2)
    brk()

@jax.jit
def run():
    jax.eval_shape(fn)
    fn()

run()

)

What jax/jaxlib version are you using?

0.4.13

@DiegoRenner

I'm having a similar issue where the code runs fine without breakpoints but gives jax.errors.UnexpectedTracerError while trying to debug.

I tried checking where it could be coming from with jax.check_tracer_leaks, but it points me to tracers that aren't leaking.

Let me know if I can help with details if this seems related.

patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Oct 10, 2023
- Now has a customisable number of frames, according to the
  `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable. This is useful
  as JAX has a bug when combining certain operations with many frames
  (jax-ml/jax#16732).
- `_module.py` is now excluded from tracebacks.
patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Oct 12, 2023, with the same commit message.
@cool-RR
Contributor

cool-RR commented Apr 29, 2024

Hi folks. This problem is affecting me right now with jax==0.4.26. Solution aside, is there a workaround? I'm not sure how to debug my program otherwise...
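This is not a fix for the underlying bug, but one possible stopgap when you only need to inspect values rather than step through an interactive session: `jax.debug.callback` (and likewise `jax.debug.print`) only receives the arguments you pass it explicitly and does not collect locals from surrounding stack frames, so it should not pick up `x0` the way the breakpoint does. A minimal sketch, assuming a hypothetical host-side `_record` function and a variant of `brk` that takes an argument:

```python
import numpy as np

import jax
import jax.numpy as jnp

# Host-side log of values seen at runtime. Only what is passed explicitly
# to the callback is recorded; no stack frames are inspected.
recorded = []


def _record(y):  # hypothetical helper
    recorded.append(np.asarray(y))


@jax.jit
def brk(x):
    y = jnp.sin(x)
    jax.debug.callback(_record, y)  # explicit, frame-free inspection
    return y


def fn():
    x0 = jnp.zeros(2)
    return brk(x0)


jax.eval_shape(fn)     # traces only; the callback does not run here
fn()                   # executes; the callback fires with sin(x0)
jax.effects_barrier()  # wait for pending callbacks to finish
```

The trade-off is that you choose up front which values to look at, instead of getting an interactive prompt with every local in scope.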
