Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracebacks no longer have JAX-internal frames prepended by default #16949

Merged
merged 1 commit into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,17 +1025,21 @@ def _update_disable_jit_thread_local(val):

traceback_filtering = config.define_enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
"auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
" * \"off\": disables traceback filtering.\n"
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
"new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
" * \"auto\": use \"tracebackhide\" if running under a sufficiently"
" new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
" hidden stack frames, which some traceback printers support.\n"
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
" the unfiltered traceback as a __cause__ of the exception.\n")
" * \"remove_frames\": removes hidden frames from tracebacks, and adds"
" the unfiltered traceback as a __cause__ of the exception.\n"
" * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
" a brief message (to the __cause__ of the exception) describing that this has"
" happened.\n")

# This flag is for internal use.
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
Expand Down
52 changes: 37 additions & 15 deletions jax/_src/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import functools
import os
import sys
import traceback
import types
from typing import Any, Callable, Optional, TypeVar, cast
Expand Down Expand Up @@ -114,6 +115,16 @@ def format_exception_only(e: BaseException) -> str:

class UnfilteredStackTrace(Exception): pass

_simplified_tb_msg = ("For simplicity, JAX has removed its internal frames from the "
"traceback of the following exception. Set "
"JAX_TRACEBACK_FILTERING=off to include these.")

class SimplifiedTraceback(Exception):
def __str__(self):
return _simplified_tb_msg

SimplifiedTraceback.__module__ = "jax.errors"

def _running_under_ipython() -> bool:
"""Returns true if we appear to be in an IPython session."""
try:
Expand All @@ -133,7 +144,7 @@ def _filtering_mode() -> str:
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide"
else:
mode = "remove_frames"
mode = "quiet_remove_frames"
return mode

def api_boundary(fun: C) -> C:
Expand Down Expand Up @@ -171,22 +182,12 @@ def reraise_with_filtered_traceback(*args, **kwargs):
if mode == "tracebackhide":
_add_tracebackhide_to_hidden_frames(e.__traceback__)
raise
assert mode == "remove_frames", mode

filtered_tb, unfiltered, mode = None, None, None
filtered_tb, unfiltered = None, None
try:
filtered_tb = filter_traceback(e.__traceback__)
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
unfiltered = UnfilteredStackTrace(msg)
unfiltered.with_traceback(_add_call_stack_frames(e.__traceback__))
unfiltered.__context__ = e.__context__
unfiltered.__cause__ = e.__cause__
unfiltered.__suppress_context__ = e.__suppress_context__
e.__context__ = None
e.__cause__ = unfiltered

e.__traceback__ = filtered_tb
tb = e.__traceback__
filtered_tb = filter_traceback(tb)
e.with_traceback(filtered_tb)
# In Python < 3.11, there seems to be no way to alter the currently
# raised exception traceback, except via the C API. The interpreter
# keeps a copy of the traceback (exc_traceback) that is separate to the
Expand All @@ -195,7 +196,28 @@ def reraise_with_filtered_traceback(*args, **kwargs):
# the XLA extension no longer defines a traceback-replacing method at
# Python 3.11 and onward.
if hasattr(xla_extension, "replace_thread_exc_traceback"):
# TODO(kidger): remove this line once Python 3.11 is the minimum supported
# version.
xla_extension.replace_thread_exc_traceback(filtered_tb)
if sys.version_info >= (3, 11) and mode == "quiet_remove_frames":
e.add_note("--------------------\n" + _simplified_tb_msg)
else:
if mode == "quiet_remove_frames":
# TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum
# supported version.
jax_error = SimplifiedTraceback()
elif mode == "remove_frames":
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
jax_error = UnfilteredStackTrace(msg)
jax_error.with_traceback(_add_call_stack_frames(tb))
else:
raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.")
jax_error.__cause__ = e.__cause__
jax_error.__context__ = e.__context__
jax_error.__suppress_context__ = e.__suppress_context__
e.__cause__ = jax_error
e.__context__ = None
raise
finally:
del filtered_tb
Expand Down
1 change: 1 addition & 0 deletions jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
)
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
31 changes: 23 additions & 8 deletions tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import re
import sys
import traceback

from absl.testing import absltest
Expand Down Expand Up @@ -46,7 +47,12 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = e.__cause__
if filter_mode == "remove_frames":
if filter_mode == "quiet_remove_frames":
if sys.version_info >= (3, 11):
assert any("For simplicity" in x for x in e.__notes__)
else:
test.assertIsInstance(c, jax.errors.SimplifiedTraceback)
elif filter_mode == "remove_frames":
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
else:
test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))
Expand Down Expand Up @@ -74,7 +80,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
@jtu.with_config(jax_traceback_filtering='auto') # JaxTestCase defaults to off.
@parameterized.named_parameters(
{"testcase_name": f"_{f}", "filter_mode": f}
for f in ("tracebackhide", "remove_frames"))
for f in ("tracebackhide", "remove_frames", "quiet_remove_frames"))
class FilteredTracebackTest(jtu.JaxTestCase):

def test_nested_jit(self, filter_mode):
Expand Down Expand Up @@ -347,9 +353,13 @@ def outer(x):
check_filtered_stack_trace(self, TypeError, f, [
('<lambda>', 'f = lambda: outer'),
('outer', 'raise TypeError')], filter_mode=filter_mode)
e = get_exception(TypeError, f)
self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace)
self.assertIsInstance(e.__cause__.__cause__, ValueError)
e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto
if sys.version_info >= (3, 11):
assert any("For simplicity" in x for x in e.__notes__)
self.assertIsInstance(e.__cause__, ValueError)
else:
self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback)
self.assertIsInstance(e.__cause__.__cause__, ValueError)

def test_null_traceback(self, filter_mode):
class TestA: pass
Expand All @@ -375,9 +385,14 @@ def test_grad_norm(self):
e = exc
self.assertIsNot(e, None)
self.assertIn("invalid value", str(e))
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
if sys.version_info >= (3, 11):
self.assertIsInstance(
e.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
else:
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)


class CustomErrorsTest(jtu.JaxTestCase):
Expand Down