Skip to content

Commit

Permalink
Merge pull request #16949 from patrick-kidger:simplified-traceback
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553580442
  • Loading branch information
jax authors committed Aug 3, 2023
2 parents 7f068dc + 5e276d0 commit a8388e2
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 29 deletions.
16 changes: 10 additions & 6 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,17 +1025,21 @@ def _update_disable_jit_thread_local(val):

traceback_filtering = config.define_enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
"auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
" * \"off\": disables traceback filtering.\n"
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
"new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
" * \"auto\": use \"tracebackhide\" if running under a sufficiently"
" new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
" hidden stack frames, which some traceback printers support.\n"
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
" the unfiltered traceback as a __cause__ of the exception.\n")
" * \"remove_frames\": removes hidden frames from tracebacks, and adds"
" the unfiltered traceback as a __cause__ of the exception.\n"
" * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
" a brief message (to the __cause__ of the exception) describing that this has"
" happened.\n")

# This flag is for internal use.
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
Expand Down
52 changes: 37 additions & 15 deletions jax/_src/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import functools
import os
import sys
import traceback
import types
from typing import Any, Callable, Optional, TypeVar, cast
Expand Down Expand Up @@ -114,6 +115,16 @@ def format_exception_only(e: BaseException) -> str:

class UnfilteredStackTrace(Exception): pass

_simplified_tb_msg = ("For simplicity, JAX has removed its internal frames from the "
"traceback of the following exception. Set "
"JAX_TRACEBACK_FILTERING=off to include these.")

class SimplifiedTraceback(Exception):
def __str__(self):
return _simplified_tb_msg

SimplifiedTraceback.__module__ = "jax.errors"

def _running_under_ipython() -> bool:
"""Returns true if we appear to be in an IPython session."""
try:
Expand All @@ -133,7 +144,7 @@ def _filtering_mode() -> str:
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide"
else:
mode = "remove_frames"
mode = "quiet_remove_frames"
return mode

def api_boundary(fun: C) -> C:
Expand Down Expand Up @@ -171,22 +182,12 @@ def reraise_with_filtered_traceback(*args, **kwargs):
if mode == "tracebackhide":
_add_tracebackhide_to_hidden_frames(e.__traceback__)
raise
assert mode == "remove_frames", mode

filtered_tb, unfiltered, mode = None, None, None
filtered_tb, unfiltered = None, None
try:
filtered_tb = filter_traceback(e.__traceback__)
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
unfiltered = UnfilteredStackTrace(msg)
unfiltered.with_traceback(_add_call_stack_frames(e.__traceback__))
unfiltered.__context__ = e.__context__
unfiltered.__cause__ = e.__cause__
unfiltered.__suppress_context__ = e.__suppress_context__
e.__context__ = None
e.__cause__ = unfiltered

e.__traceback__ = filtered_tb
tb = e.__traceback__
filtered_tb = filter_traceback(tb)
e.with_traceback(filtered_tb)
# In Python < 3.11, there seems to be no way to alter the currently
# raised exception traceback, except via the C API. The interpreter
# keeps a copy of the traceback (exc_traceback) that is separate to the
Expand All @@ -195,7 +196,28 @@ def reraise_with_filtered_traceback(*args, **kwargs):
# the XLA extension no longer defines a traceback-replacing method at
# Python 3.11 and onward.
if hasattr(xla_extension, "replace_thread_exc_traceback"):
# TODO(kidger): remove this line once Python 3.11 is the minimum supported
# version.
xla_extension.replace_thread_exc_traceback(filtered_tb)
if sys.version_info >= (3, 11) and mode == "quiet_remove_frames":
e.add_note("--------------------\n" + _simplified_tb_msg)
else:
if mode == "quiet_remove_frames":
# TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum
# supported version.
jax_error = SimplifiedTraceback()
elif mode == "remove_frames":
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
jax_error = UnfilteredStackTrace(msg)
jax_error.with_traceback(_add_call_stack_frames(tb))
else:
raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.")
jax_error.__cause__ = e.__cause__
jax_error.__context__ = e.__context__
jax_error.__suppress_context__ = e.__suppress_context__
e.__cause__ = jax_error
e.__context__ = None
raise
finally:
del filtered_tb
Expand Down
1 change: 1 addition & 0 deletions jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
)
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
31 changes: 23 additions & 8 deletions tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import re
import sys
import traceback

from absl.testing import absltest
Expand Down Expand Up @@ -46,7 +47,12 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = e.__cause__
if filter_mode == "remove_frames":
if filter_mode == "quiet_remove_frames":
if sys.version_info >= (3, 11):
assert any("For simplicity" in x for x in e.__notes__)
else:
test.assertIsInstance(c, jax.errors.SimplifiedTraceback)
elif filter_mode == "remove_frames":
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
else:
test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))
Expand Down Expand Up @@ -74,7 +80,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
@jtu.with_config(jax_traceback_filtering='auto') # JaxTestCase defaults to off.
@parameterized.named_parameters(
{"testcase_name": f"_{f}", "filter_mode": f}
for f in ("tracebackhide", "remove_frames"))
for f in ("tracebackhide", "remove_frames", "quiet_remove_frames"))
class FilteredTracebackTest(jtu.JaxTestCase):

def test_nested_jit(self, filter_mode):
Expand Down Expand Up @@ -347,9 +353,13 @@ def outer(x):
check_filtered_stack_trace(self, TypeError, f, [
('<lambda>', 'f = lambda: outer'),
('outer', 'raise TypeError')], filter_mode=filter_mode)
e = get_exception(TypeError, f)
self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace)
self.assertIsInstance(e.__cause__.__cause__, ValueError)
e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto
if sys.version_info >= (3, 11):
assert any("For simplicity" in x for x in e.__notes__)
self.assertIsInstance(e.__cause__, ValueError)
else:
self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback)
self.assertIsInstance(e.__cause__.__cause__, ValueError)

def test_null_traceback(self, filter_mode):
class TestA: pass
Expand All @@ -375,9 +385,14 @@ def test_grad_norm(self):
e = exc
self.assertIsNot(e, None)
self.assertIn("invalid value", str(e))
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
if sys.version_info >= (3, 11):
self.assertIsInstance(
e.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
else:
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)


class CustomErrorsTest(jtu.JaxTestCase):
Expand Down

0 comments on commit a8388e2

Please sign in to comment.