jax.debug.breakpoint crashes in a hard-to-describe way.
#16732
Comments
I'm having a similar issue: the code runs fine without breakpoints but raises `jax.errors.UnexpectedTracerError` while trying to debug. I tried checking where it could be coming from with `jax.check_tracer_leaks`, but it points me to tracers that aren't leaking. Let me know if I can help with details if this seems related.
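For comparison, a minimal sketch of how `jax.check_tracer_leaks` is normally used: a function that stashes its traced argument in a global list leaks a tracer, and with the check enabled JAX raises once the `jax.jit` trace finishes. This is a deliberately leaky toy, not the breakpoint bug itself; `leaky`, `clean` and `raises_leak_error` are hypothetical names, and because the exception type and message differ between JAX versions the sketch only catches `Exception`.

```python
import jax

leaked = []


def leaky(x):
    # Keeps a reference to the tracer after tracing has finished: a leak.
    leaked.append(x)
    return x + 1.0


def clean(x):
    return x + 1.0


def raises_leak_error(fn) -> bool:
    """Trace `fn` under jit with leak checking on; report whether it raised."""
    try:
        with jax.check_tracer_leaks(True):
            jax.jit(fn)(1.0)
    except Exception:
        return True
    return False


# Each function is traced once; only `leaky` should trip the check.
clean_raises = raises_leak_error(clean)
leaky_raises = raises_leak_error(leaky)
print(clean_raises, leaky_raises)
```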
- Now has a customisable number of frames, according to the `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable. This is useful as JAX has a bug when combining certain operations with many frames (jax-ml/jax#16732).
- `_module.py` is now excluded from tracebacks.
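A sketch of the frame-limiting idea, assuming the variable holds a plain integer and that an unset or empty variable means "no limit" (both are assumptions; this is not Equinox's actual code). It relies on the `num_frames` argument of `jax.debug.breakpoint`, which caps how many stack frames above the call are made available to the debugger. `breakpoint_frames_from_env` and `debug_step` are hypothetical names; calling `debug_step` opens the interactive debugger, so only the helper is meant to run non-interactively.

```python
import os
from typing import Optional

import jax


def breakpoint_frames_from_env(
    var: str = "EQX_ON_ERROR_BREAKPOINT_FRAMES",
) -> Optional[int]:
    """Hypothetical helper: read a frame limit from the environment."""
    raw = os.environ.get(var, "").strip()
    if not raw:
        # Unset or empty: return None, leaving jax.debug.breakpoint's
        # default (no explicit frame limit) in place.
        return None
    return int(raw)


def debug_step(x):
    # Fewer captured frames means fewer locals (possibly tracers from
    # enclosing transformations) are grabbed by the breakpoint.
    jax.debug.breakpoint(num_frames=breakpoint_frames_from_env())
    return x * 2
```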
Hi folks. This problem is affecting me right now with …
Description
This:
produces:
I think what's happening is something like the following. The `eval_shape` call means that `x0` is a tracer. Then `jax.debug.breakpoint` grabs `x0` from the stack frame above it. Then `jax.jit` saves it as a constant of its jaxpr. This means that JAX itself leaks the tracer. Tagging @sharadmv.
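The reproduction itself is not preserved in this copy. As a hypothetical sketch of the first ingredients of that explanation (not the reporter's code): inside `jax.eval_shape` the argument `x0` is an abstract tracer, and a `jax.jit`-decorated function defined there that closes over `x0` has to capture that tracer from the enclosing trace. `f`, `g` and `observed` are illustrative names, and the sketch deliberately stops short of calling `jax.debug.breakpoint`, which is interactive.

```python
import jax
import jax.numpy as jnp

observed = {}


def f(x0):
    # Under eval_shape, x0 is an abstract tracer rather than a concrete array.
    # Record only a bool so the tracer itself does not outlive the trace.
    observed["x0_is_tracer"] = isinstance(x0, jax.core.Tracer)

    @jax.jit
    def g(y):
        # g closes over x0, so tracing g captures a tracer that belongs to
        # the outer eval_shape trace. This is harmless as long as nothing
        # retains g's traced jaxpr after eval_shape returns; if something
        # does, x0 has leaked.
        return y + x0

    return g(x0)


out = jax.eval_shape(f, jnp.ones(3))
print(out.shape, out.dtype)
```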
(Also, it doesn't suffice to just stop grabbing stack frames once you leave the JIT'd region. Consider this variant:
)
What jax/jaxlib version are you using?
0.4.13