Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing for _core/_multierror.py #2742

Merged
merged 16 commits into from
Aug 14, 2023
Merged
119 changes: 89 additions & 30 deletions trio/_core/_multierror.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import sys
import warnings
from typing import TYPE_CHECKING
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any

import attr

Expand All @@ -15,12 +16,17 @@

if TYPE_CHECKING:
from types import TracebackType

from mypy_extensions import DefaultNamedArg
from typing_extensions import Self
################################################################
# MultiError
################################################################


def _filter_impl(handler, root_exc):
def _filter_impl(
handler: Callable[[BaseException], BaseException | None], root_exc: BaseException
) -> BaseException | None:
# We have a tree of MultiError's, like:
#
# MultiError([
Expand Down Expand Up @@ -79,7 +85,9 @@ def _filter_impl(handler, root_exc):

# Filters a subtree, ignoring tracebacks, while keeping a record of
# which MultiErrors were preserved unchanged
def filter_tree(exc, preserved):
def filter_tree(
exc: MultiError | BaseException, preserved: set[int]
) -> MultiError | BaseException | None:
if isinstance(exc, MultiError):
new_exceptions = []
changed = False
Expand All @@ -103,7 +111,9 @@ def filter_tree(exc, preserved):
new_exc.__context__ = exc
return new_exc

def push_tb_down(tb, exc, preserved):
def push_tb_down(
tb: TracebackType | None, exc: BaseException, preserved: set[int]
) -> None:
if id(exc) in preserved:
return
new_tb = concat_tb(tb, exc.__traceback__)
Expand All @@ -114,7 +124,7 @@ def push_tb_down(tb, exc, preserved):
else:
exc.__traceback__ = new_tb

preserved = set()
preserved: set[int] = set()
new_root_exc = filter_tree(root_exc, preserved)
push_tb_down(None, root_exc, preserved)
# Delete the local functions to avoid a reference cycle (see
Expand All @@ -130,9 +140,9 @@ def push_tb_down(tb, exc, preserved):
# frame show up in the traceback; otherwise, we leave no trace.)
@attr.s(frozen=True)
class MultiErrorCatcher:
_handler = attr.ib()
handler: Callable[[BaseException], BaseException | None] = attr.ib()

def __enter__(self):
def __enter__(self) -> None:
pass

def __exit__(
Expand All @@ -142,7 +152,7 @@ def __exit__(
traceback: TracebackType | None,
) -> bool | None:
if exc_value is not None:
filtered_exc = _filter_impl(self._handler, exc_value)
filtered_exc = _filter_impl(self.handler, exc_value)

if filtered_exc is exc_value:
# Let the interpreter re-raise it
Expand All @@ -167,7 +177,13 @@ def __exit__(
return False


class MultiError(BaseExceptionGroup):
if TYPE_CHECKING:
_BaseExceptionGroup = BaseExceptionGroup[BaseException]
else:
_BaseExceptionGroup = BaseExceptionGroup


class MultiError(_BaseExceptionGroup):
"""An exception that contains other exceptions; also known as an
"inception".

Expand All @@ -190,7 +206,9 @@ class MultiError(BaseExceptionGroup):

"""

def __init__(self, exceptions, *, _collapse=True):
def __init__(
self, exceptions: list[BaseException], *, _collapse: bool = True
) -> None:
self.collapse = _collapse

# Avoid double initialization when _collapse is True and exceptions[0] returned
Expand All @@ -201,7 +219,9 @@ def __init__(self, exceptions, *, _collapse=True):

super().__init__("multiple tasks failed", exceptions)

def __new__(cls, exceptions, *, _collapse=True):
def __new__( # type: ignore[misc] # mypy says __new__ must return a class instance
cls, exceptions: Iterable[BaseException], *, _collapse: bool = True
) -> NonBaseMultiError | MultiError | Self:
exceptions = list(exceptions)
for exc in exceptions:
if not isinstance(exc, BaseException):
Expand All @@ -210,41 +230,64 @@ def __new__(cls, exceptions, *, _collapse=True):
# If this lone object happens to itself be a MultiError, then
# Python will implicitly call our __init__ on it again. See
# special handling in __init__.
return exceptions[0]
single = exceptions[0]
assert isinstance(single, MultiError)
return single
else:
# The base class __new__() implicitly invokes our __init__, which
# is what we want.
#
# In an earlier version of the code, we didn't define __init__ and
# simply set the `exceptions` attribute directly on the new object.
# However, linters expect attributes to be initialized in __init__.
from_class: type[Self] | type[NonBaseMultiError] = cls
if all(isinstance(exc, Exception) for exc in exceptions):
cls = NonBaseMultiError

return super().__new__(cls, "multiple tasks failed", exceptions)

def __reduce__(self):
from_class = NonBaseMultiError

# Mypy is really mad about the following line:
# Ignoring arg-type: 'Argument 3 to "__new__" of "BaseExceptionGroup" has incompatible type "list[BaseException]"; expected "Sequence[_BaseExceptionT_co]"'
# We have checked that exceptions is indeed a list of BaseException objects, this is fine.
# Ignoring type-var: 'Value of type variable "Self" of "__new__" of "BaseExceptionGroup" cannot be "object"'
# Not sure how mypy is getting 'object', this is also fine.
new_obj = super().__new__(from_class, "multiple tasks failed", exceptions) # type: ignore[arg-type,type-var]
assert isinstance(new_obj, (cls, NonBaseMultiError))
return new_obj

def __reduce__(
self,
) -> tuple[
Callable[
[type[Self], Iterable[BaseException], DefaultNamedArg(bool, "_collapse")],
NonBaseMultiError | MultiError | Self,
],
tuple[type[Self], list[BaseException]],
dict[str, bool],
]:
return (
self.__new__,
(self.__class__, list(self.exceptions)),
{"collapse": self.collapse},
)

def __str__(self):
def __str__(self) -> str:
return ", ".join(repr(exc) for exc in self.exceptions)

def __repr__(self):
def __repr__(self) -> str:
return f"<MultiError: {self}>"

def derive(self, __excs):
def derive(self, __excs: list[BaseException]) -> MultiError: # type: ignore[override]
# We use _collapse=False here to get ExceptionGroup semantics, since derive()
# is part of the PEP 654 API
exc = MultiError(__excs, _collapse=False)
exc.collapse = self.collapse
return exc

@classmethod
def filter(cls, handler, root_exc):
def filter(
cls,
handler: Callable[[BaseException], BaseException | None],
root_exc: BaseException,
) -> BaseException | None:
"""Apply the given ``handler`` to all the exceptions in ``root_exc``.

Args:
Expand All @@ -268,7 +311,9 @@ def filter(cls, handler, root_exc):
return _filter_impl(handler, root_exc)

@classmethod
def catch(cls, handler):
def catch(
cls, handler: Callable[[BaseException], BaseException | None]
) -> MultiErrorCatcher:
"""Return a context manager that catches and re-throws exceptions
after running :meth:`filter` on them.

Expand All @@ -286,8 +331,14 @@ def catch(cls, handler):
return MultiErrorCatcher(handler)


class NonBaseMultiError(MultiError, ExceptionGroup):
pass
if TYPE_CHECKING:
_ExceptionGroup = ExceptionGroup[Exception]
else:
_ExceptionGroup = ExceptionGroup


class NonBaseMultiError(MultiError, _ExceptionGroup):
__slots__ = ()


# Clean up exception printing:
Expand Down Expand Up @@ -322,8 +373,8 @@ class NonBaseMultiError(MultiError, ExceptionGroup):

if have_tproxy:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb, tb_next):
def controller(operation):
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> tputil:
def controller(operation: tputil.ProxyOperation) -> TracebackType | None | Any:
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
Expand Down Expand Up @@ -359,7 +410,7 @@ class CTraceback(ctypes.Structure):
("tb_lineno", ctypes.c_int),
]

def copy_tb(base_tb, tb_next):
def copy_tb(base_tb: TracebackType, tb_next: TracebackType) -> TracebackType:
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
Expand Down Expand Up @@ -397,7 +448,9 @@ def copy_tb(base_tb, tb_next):
del new_tb, old_tb_frame


def concat_tb(head, tail):
def concat_tb(
head: TracebackType | None, tail: TracebackType | None
) -> TracebackType | None:
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
Expand All @@ -417,7 +470,7 @@ def concat_tb(head, tail):
if "IPython" in sys.modules:
import IPython

ip = IPython.get_ipython()
ip = IPython.get_ipython() # type: ignore[attr-defined] # not explicitly exported
if ip is not None:
if ip.custom_exceptions != ():
warnings.warn(
Expand All @@ -429,7 +482,13 @@ def concat_tb(head, tail):
)
else:

def trio_show_traceback(self, etype, value, tb, tb_offset=None):
def trio_show_traceback(
self: Any,
etype: Any,
value: BaseException,
tb: TracebackType,
tb_offset: Any | None = None,
) -> None:
# XX it would be better to integrate with IPython's fancy
# exception formatting stuff (and not ignore tb_offset)
print_exception(value)
Expand Down