Skip to content

Commit b808c67

Browse files
committed
Small refactor
* Import functions from `traceback` directly, to allow free use of `traceback` as a variable name. * Extract `_filtered_traceback` into a function. * Inline `_repr_exception_group_traceback` given it is used only in one place. * Make a type alias for the type of `tbfilter`.
1 parent 22ef71b commit b808c67

File tree

1 file changed

+42
-33
lines changed

1 file changed

+42
-33
lines changed

src/_pytest/_code/code.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
from pathlib import Path
1616
import re
1717
import sys
18-
import traceback
18+
from traceback import extract_tb
19+
from traceback import format_exception
1920
from traceback import format_exception_only
21+
from traceback import FrameSummary
2022
from types import CodeType
2123
from types import FrameType
2224
from types import TracebackType
@@ -28,6 +30,7 @@
2830
from typing import Literal
2931
from typing import overload
3032
from typing import SupportsIndex
33+
from typing import TYPE_CHECKING
3134
from typing import TypeVar
3235
from typing import Union
3336

@@ -208,10 +211,10 @@ def with_repr_style(
208211
def lineno(self) -> int:
209212
return self._rawentry.tb_lineno - 1
210213

211-
def get_python_framesummary(self) -> traceback.FrameSummary:
214+
def get_python_framesummary(self) -> FrameSummary:
212215
# Python's built-in traceback module implements all the nitty gritty
213216
# details to get column numbers of out frames.
214-
stack_summary = traceback.extract_tb(self._rawentry, limit=1)
217+
stack_summary = extract_tb(self._rawentry, limit=1)
215218
return stack_summary[0]
216219

217220
# Column and end line numbers introduced in python 3.11
@@ -694,8 +697,7 @@ def getrepr(
694697
showlocals: bool = False,
695698
style: TracebackStyle = "long",
696699
abspath: bool = False,
697-
tbfilter: bool
698-
| Callable[[ExceptionInfo[BaseException]], _pytest._code.code.Traceback] = True,
700+
tbfilter: TracebackFilter = True,
699701
funcargs: bool = False,
700702
truncate_locals: bool = True,
701703
truncate_args: bool = True,
@@ -742,7 +744,7 @@ def getrepr(
742744
if style == "native":
743745
return ReprExceptionInfo(
744746
reprtraceback=ReprTracebackNative(
745-
traceback.format_exception(
747+
format_exception(
746748
self.type,
747749
self.value,
748750
self.traceback[0]._rawentry if self.traceback else None,
@@ -851,6 +853,17 @@ def group_contains(
851853
return self._group_contains(self.value, expected_exception, match, depth)
852854

853855

856+
if TYPE_CHECKING:
857+
from typing_extensions import TypeAlias
858+
859+
# Type alias for the `tbfilter` setting:
860+
# bool: If True, it should be filtered using Traceback.filter()
861+
# callable: A callable that takes an ExceptionInfo and returns the filtered traceback.
862+
TracebackFilter: TypeAlias = Union[
863+
bool, Callable[[ExceptionInfo[BaseException]], Traceback]
864+
]
865+
866+
854867
@dataclasses.dataclass
855868
class FormattedExcinfo:
856869
"""Presenting information about failing Functions and Generators."""
@@ -862,7 +875,7 @@ class FormattedExcinfo:
862875
showlocals: bool = False
863876
style: TracebackStyle = "long"
864877
abspath: bool = True
865-
tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback] = True
878+
tbfilter: TracebackFilter = True
866879
funcargs: bool = False
867880
truncate_locals: bool = True
868881
truncate_args: bool = True
@@ -1099,16 +1112,8 @@ def _makepath(self, path: Path | str) -> str:
10991112
return np
11001113
return str(path)
11011114

1102-
def _filtered_traceback(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
1103-
"""Filter the exception traceback in ``excinfo`` according to ``tbfilter``, if any."""
1104-
if callable(self.tbfilter):
1105-
return self.tbfilter(excinfo)
1106-
elif self.tbfilter:
1107-
return excinfo.traceback.filter(excinfo)
1108-
return excinfo.traceback
1109-
11101115
def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> ReprTraceback:
1111-
traceback = self._filtered_traceback(excinfo)
1116+
traceback = filter_excinfo_traceback(self.tbfilter, excinfo)
11121117

11131118
if isinstance(excinfo.value, RecursionError):
11141119
traceback, extraline = self._truncate_recursive_traceback(traceback)
@@ -1132,18 +1137,6 @@ def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> ReprTraceback
11321137
]
11331138
return ReprTraceback(entries, extraline, style=self.style)
11341139

1135-
def _repr_exception_group_traceback(
1136-
self, excinfo: ExceptionInfo[BaseException]
1137-
) -> ReprTracebackNative:
1138-
traceback_ = self._filtered_traceback(excinfo)
1139-
return ReprTracebackNative(
1140-
traceback.format_exception(
1141-
type(excinfo.value),
1142-
excinfo.value,
1143-
traceback_[0]._rawentry,
1144-
)
1145-
)
1146-
11471140
def _truncate_recursive_traceback(
11481141
self, traceback: Traceback
11491142
) -> tuple[Traceback, str | None]:
@@ -1194,19 +1187,23 @@ def repr_excinfo(self, excinfo: ExceptionInfo[BaseException]) -> ExceptionChainR
11941187
# Fall back to native traceback as a temporary workaround until
11951188
# full support for exception groups added to ExceptionInfo.
11961189
# See https://github.com/pytest-dev/pytest/issues/9159
1190+
reprtraceback: ReprTraceback | ReprTracebackNative
11971191
if isinstance(e, BaseExceptionGroup):
1198-
reprtraceback: ReprTracebackNative | ReprTraceback = (
1199-
self._repr_exception_group_traceback(excinfo_)
1192+
traceback = filter_excinfo_traceback(self.tbfilter, excinfo)
1193+
reprtraceback = ReprTracebackNative(
1194+
format_exception(
1195+
type(excinfo.value),
1196+
excinfo.value,
1197+
traceback[0]._rawentry,
1198+
)
12001199
)
12011200
else:
12021201
reprtraceback = self.repr_traceback(excinfo_)
12031202
reprcrash = excinfo_._getreprcrash()
12041203
else:
12051204
# Fallback to native repr if the exception doesn't have a traceback:
12061205
# ExceptionInfo objects require a full traceback to work.
1207-
reprtraceback = ReprTracebackNative(
1208-
traceback.format_exception(type(e), e, None)
1209-
)
1206+
reprtraceback = ReprTracebackNative(format_exception(type(e), e, None))
12101207
reprcrash = None
12111208
repr_chain += [(reprtraceback, reprcrash, descr)]
12121209

@@ -1555,3 +1552,15 @@ def filter_traceback(entry: TracebackEntry) -> bool:
15551552
return False
15561553

15571554
return True
1555+
1556+
1557+
def filter_excinfo_traceback(
1558+
tbfilter: TracebackFilter, excinfo: ExceptionInfo[BaseException]
1559+
) -> Traceback:
1560+
"""Filter the exception traceback in ``excinfo`` according to ``tbfilter``."""
1561+
if callable(tbfilter):
1562+
return tbfilter(excinfo)
1563+
elif tbfilter:
1564+
return excinfo.traceback.filter(excinfo)
1565+
else:
1566+
return excinfo.traceback

0 commit comments

Comments
 (0)