Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracebacks no longer have JAX-internal frames prepended by default #16949

Merged
merged 1 commit into from
Aug 3, 2023

Conversation

patrick-kidger
Copy link
Collaborator

@patrick-kidger patrick-kidger commented Aug 3, 2023

This change is to make JAX errors a little less inscrutable.

When running:

import jax

@jax.grad
def g(x):
    bool(x)  # not valid code

@jax.jit
def f(x):
    return g(x)

f(1.)

Then for Python >=3.11 we now get (using the new add_notes feature):

❯ python file.py
Traceback (most recent call last):
  File "file.py", line 11, in <module>
    f(1.)
  File "file.py", line 9, in f
    return g(x)
           ^^^^
  File "file.py", line 5, in g
    bool(x)  # not valid code
    ^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float32[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

and for Python <3.11 we now get:

❯ python file.py       
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "file.py", line 11, in <module>
    f(1.)
  File "file.py", line 9, in f
    return g(x)
           ^^^^
  File "file.py", line 5, in g
    bool(x)  # not valid code
    ^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float32[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

and previously, we used to always get:

❯ python file.py
Traceback (most recent call last):
  File "file.py", line 11, in <module>
    f(1.)
  File "jax/_src/traceback_util.py", line 174, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/pjit.py", line 252, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/pjit.py", line 160, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
                                                 ^^^^^^^^^^^^^^^^
  File "jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/pjit.py", line 486, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
                                                      ^^^^^^^^^^^^
  File "jax/_src/pjit.py", line 964, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
                                    ^^^^^^^^^^^^^^^^^^^
  File "jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "jax/_src/pjit.py", line 917, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "file.py", line 9, in f
    return g(x)
           ^^^^
  File "jax/_src/traceback_util.py", line 174, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/api.py", line 665, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/traceback_util.py", line 174, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/api.py", line 741, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/api.py", line 2247, in _vjp
    out_primal, out_vjp = ad.vjp(
                          ^^^^^^^
  File "jax/_src/interpreters/ad.py", line 140, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/interpreters/ad.py", line 129, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
                                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "file.py", line 5, in g
    bool(x)  # not valid code
    ^^^^^^^
  File "jax/_src/core.py", line 673, in __bool__
    def __bool__(self): return self.aval._bool(self)
                               ^^^^^^^^^^^^^^^^^^^^^
  File "jax/_src/core.py", line 1383, in error
    raise TracerBoolConversionError(arg)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float32[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "file.py", line 11, in <module>
    f(1.)
  File "file.py", line 9, in f
    return g(x)
           ^^^^
  File "file.py", line 5, in g
    bool(x)  # not valid code
    ^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape float32[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Copy link
Collaborator

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems promising.

Would you mind including some sample before and after output to the PR description?

I am also wondering: would it make sense to use a "note" under Python 3.11+ instead? https://peps.python.org/pep-0678/

@patrick-kidger
Copy link
Collaborator Author

Done!
Agreed, I like the use of notes. I've used that for Python >=3.11.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Aug 3, 2023
@copybara-service copybara-service bot merged commit a8388e2 into jax-ml:main Aug 3, 2023
7 checks passed
@patrick-kidger patrick-kidger deleted the simplified-traceback branch August 3, 2023 21:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants