Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing for _core/_multierror.py #2742

Merged
merged 16 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ module = [
"trio._abc",
"trio._core._entry_queue",
"trio._core._local",
"trio._core._unbounded_queue",
"trio._core._multierror",
"trio._core._thread_cache",
"trio._core._unbounded_queue",
"trio._deprecate",
"trio._dtls",
"trio._file_io",
Expand Down
168 changes: 113 additions & 55 deletions trio/_core/_multierror.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import sys
import warnings
from typing import TYPE_CHECKING
from collections.abc import Callable, Iterable, Sequence
from types import TracebackType
from typing import TYPE_CHECKING, Any, cast, overload

import attr

Expand All @@ -14,13 +16,15 @@
from traceback import print_exception

if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import Self
################################################################
# MultiError
################################################################


def _filter_impl(handler, root_exc):
def _filter_impl(
handler: Callable[[BaseException], BaseException | None], root_exc: BaseException
) -> BaseException | None:
# We have a tree of MultiError's, like:
#
# MultiError([
Expand Down Expand Up @@ -79,7 +83,9 @@ def _filter_impl(handler, root_exc):

# Filters a subtree, ignoring tracebacks, while keeping a record of
# which MultiErrors were preserved unchanged
def filter_tree(exc, preserved):
def filter_tree(
exc: MultiError | BaseException, preserved: set[int]
) -> MultiError | BaseException | None:
if isinstance(exc, MultiError):
new_exceptions = []
changed = False
Expand All @@ -103,7 +109,9 @@ def filter_tree(exc, preserved):
new_exc.__context__ = exc
return new_exc

def push_tb_down(tb, exc, preserved):
def push_tb_down(
tb: TracebackType | None, exc: BaseException, preserved: set[int]
) -> None:
if id(exc) in preserved:
return
new_tb = concat_tb(tb, exc.__traceback__)
Expand All @@ -114,7 +122,7 @@ def push_tb_down(tb, exc, preserved):
else:
exc.__traceback__ = new_tb

preserved = set()
preserved: set[int] = set()
new_root_exc = filter_tree(root_exc, preserved)
push_tb_down(None, root_exc, preserved)
# Delete the local functions to avoid a reference cycle (see
Expand All @@ -130,9 +138,9 @@ def push_tb_down(tb, exc, preserved):
# frame show up in the traceback; otherwise, we leave no trace.)
@attr.s(frozen=True)
class MultiErrorCatcher:
_handler = attr.ib()
_handler: Callable[[BaseException], BaseException | None] = attr.ib()

def __enter__(self):
def __enter__(self) -> None:
pass

def __exit__(
Expand Down Expand Up @@ -167,7 +175,13 @@ def __exit__(
return False


class MultiError(BaseExceptionGroup):
if TYPE_CHECKING:
_BaseExceptionGroup = BaseExceptionGroup[BaseException]
else:
_BaseExceptionGroup = BaseExceptionGroup


class MultiError(_BaseExceptionGroup):
"""An exception that contains other exceptions; also known as an
"inception".

Expand All @@ -190,7 +204,9 @@ class MultiError(BaseExceptionGroup):

"""

def __init__(self, exceptions, *, _collapse=True):
def __init__(
self, exceptions: list[BaseException], *, _collapse: bool = True
) -> None:
self.collapse = _collapse

# Avoid double initialization when _collapse is True and exceptions[0] returned
Expand All @@ -201,7 +217,9 @@ def __init__(self, exceptions, *, _collapse=True):

super().__init__("multiple tasks failed", exceptions)

def __new__(cls, exceptions, *, _collapse=True):
def __new__( # type: ignore[misc] # mypy says __new__ must return a class instance
cls, exceptions: Iterable[BaseException], *, _collapse: bool = True
) -> NonBaseMultiError | Self | BaseException:
exceptions = list(exceptions)
for exc in exceptions:
if not isinstance(exc, BaseException):
Expand All @@ -218,33 +236,54 @@ def __new__(cls, exceptions, *, _collapse=True):
# In an earlier version of the code, we didn't define __init__ and
# simply set the `exceptions` attribute directly on the new object.
# However, linters expect attributes to be initialized in __init__.
from_class: type[Self] | type[NonBaseMultiError] = cls
if all(isinstance(exc, Exception) for exc in exceptions):
cls = NonBaseMultiError
from_class = NonBaseMultiError

return super().__new__(cls, "multiple tasks failed", exceptions)
# Ignoring arg-type: 'Argument 3 to "__new__" of "BaseExceptionGroup" has incompatible type "list[BaseException]"; expected "Sequence[_BaseExceptionT_co]"'
# We have checked that exceptions is indeed a list of BaseException objects, this is fine.
new_obj = super().__new__(from_class, "multiple tasks failed", exceptions) # type: ignore[arg-type]
assert isinstance(new_obj, (cls, NonBaseMultiError))
return new_obj

def __reduce__(self):
def __reduce__(
self,
) -> tuple[object, tuple[type[Self], list[BaseException]], dict[str, bool],]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the trailing comma?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used black formatting, not sure, I can try removing it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, still there after removing and re-running black format. Maybe black has a bug?

return (
self.__new__,
(self.__class__, list(self.exceptions)),
{"collapse": self.collapse},
)

def __str__(self):
def __str__(self) -> str:
return ", ".join(repr(exc) for exc in self.exceptions)

def __repr__(self):
def __repr__(self) -> str:
return f"<MultiError: {self}>"

def derive(self, __excs):
@overload
def derive(self, __excs: Sequence[Exception]) -> NonBaseMultiError:
...

@overload
def derive(self, __excs: Sequence[BaseException]) -> MultiError:
...

def derive(
self, __excs: Sequence[Exception | BaseException]
) -> NonBaseMultiError | MultiError:
# We use _collapse=False here to get ExceptionGroup semantics, since derive()
# is part of the PEP 654 API
exc = MultiError(__excs, _collapse=False)
exc = MultiError(list(__excs), _collapse=False)
exc.collapse = self.collapse
return exc

@classmethod
def filter(cls, handler, root_exc):
def filter(
cls,
handler: Callable[[BaseException], BaseException | None],
root_exc: BaseException,
) -> BaseException | None:
"""Apply the given ``handler`` to all the exceptions in ``root_exc``.

Args:
Expand All @@ -268,7 +307,9 @@ def filter(cls, handler, root_exc):
return _filter_impl(handler, root_exc)

@classmethod
def catch(cls, handler):
def catch(
cls, handler: Callable[[BaseException], BaseException | None]
) -> MultiErrorCatcher:
"""Return a context manager that catches and re-throws exceptions
after running :meth:`filter` on them.

Expand All @@ -286,8 +327,14 @@ def catch(cls, handler):
return MultiErrorCatcher(handler)


class NonBaseMultiError(MultiError, ExceptionGroup):
pass
if TYPE_CHECKING:
_ExceptionGroup = ExceptionGroup[Exception]
else:
_ExceptionGroup = ExceptionGroup


class NonBaseMultiError(MultiError, _ExceptionGroup):
__slots__ = ()


# Clean up exception printing:
Expand Down Expand Up @@ -316,30 +363,6 @@ class NonBaseMultiError(MultiError, ExceptionGroup):
try:
import tputil
except ImportError:
have_tproxy = False
else:
have_tproxy = True

if have_tproxy:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb, tb_next):
def controller(operation):
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if operation.opname in [
"__getattribute__",
"__getattr__",
]: # pragma: no cover
if operation.args[0] == "tb_next":
return tb_next
return operation.delegate()

return tputil.make_proxy(controller, type(base_tb), base_tb)

else:
# ctypes it is
import ctypes

Expand All @@ -359,12 +382,13 @@ class CTraceback(ctypes.Structure):
("tb_lineno", ctypes.c_int),
]

def copy_tb(base_tb, tb_next):
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
except ValueError as exc:
new_tb = exc.__traceback__
assert new_tb is not None
c_new_tb = CTraceback.from_address(id(new_tb))

# At the C level, tb_next either pointer to the next traceback or is
Expand All @@ -377,14 +401,14 @@ def copy_tb(base_tb, tb_next):
# which it already is, so we're done. Otherwise, we have to actually
# do some work:
if tb_next is not None:
_ctypes.Py_INCREF(tb_next)
_ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined]
c_new_tb.tb_next = id(tb_next)

assert c_new_tb.tb_frame is not None
_ctypes.Py_INCREF(base_tb.tb_frame)
_ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined]
old_tb_frame = new_tb.tb_frame
c_new_tb.tb_frame = id(base_tb.tb_frame)
_ctypes.Py_DECREF(old_tb_frame)
_ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined]

c_new_tb.tb_lasti = base_tb.tb_lasti
c_new_tb.tb_lineno = base_tb.tb_lineno
Expand All @@ -396,8 +420,32 @@ def copy_tb(base_tb, tb_next):
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del new_tb, old_tb_frame

else:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# Mypy refuses to believe that ProxyOperation can be imported properly
def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported]
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if operation.opname in [
"__getattribute__",
"__getattr__",
]: # pragma: no cover
if operation.args[0] == "tb_next":
return tb_next
return operation.delegate() # Deligate is reverting to original behaviour

return cast(
TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb)
) # Returns proxy to traceback


def concat_tb(head, tail):
def concat_tb(
head: TracebackType | None, tail: TracebackType | None
) -> TracebackType | None:
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
Expand Down Expand Up @@ -429,7 +477,13 @@ def concat_tb(head, tail):
)
else:

def trio_show_traceback(self, etype, value, tb, tb_offset=None):
def trio_show_traceback(
self: IPython.core.interactiveshell.InteractiveShell,
etype: type[BaseException],
value: BaseException,
tb: TracebackType,
tb_offset: int | None = None,
) -> None:
# XX it would be better to integrate with IPython's fancy
# exception formatting stuff (and not ignore tb_offset)
print_exception(value)
Expand Down Expand Up @@ -460,10 +514,14 @@ def trio_show_traceback(self, etype, value, tb, tb_offset=None):

assert sys.excepthook is apport_python_hook.apport_excepthook

def replacement_excepthook(etype, value, tb):
sys.stderr.write("".join(format_exception(etype, value, tb)))
def replacement_excepthook(
etype: type[BaseException], value: BaseException, tb: TracebackType | None
) -> None:
# This does work, it's an overloaded function
sys.stderr.write("".join(format_exception(etype, value, tb))) # type: ignore[arg-type]

fake_sys = ModuleType("trio_fake_sys")
fake_sys.__dict__.update(sys.__dict__)
fake_sys.__excepthook__ = replacement_excepthook # type: ignore
# Fake does not have __excepthook__ attribute, but we are about to replace real sys
fake_sys.__excepthook__ = replacement_excepthook # type: ignore[attr-defined]
apport_python_hook.sys = fake_sys