diff --git a/async_timeout/__init__.py b/async_timeout/__init__.py index 4188a98..37c4971 100644 --- a/async_timeout/__init__.py +++ b/async_timeout/__init__.py @@ -113,7 +113,7 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> Optional[bool]: - self._do_exit(exc_type) + self._do_exit(exc_type, exc_val) return None async def __aenter__(self) -> "Timeout": @@ -126,7 +126,7 @@ async def __aexit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> Optional[bool]: - self._do_exit(exc_type) + self._do_exit(exc_type, exc_val) return None @property @@ -206,17 +206,34 @@ def _do_enter(self) -> None: self._state = _State.ENTER self._reschedule() - def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + def _do_exit( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + ) -> None: if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: - self._timeout_handler = None - raise asyncio.TimeoutError - # timeout has not expired - self._state = _State.EXIT + skip = False + if sys.version_info >= (3, 9): + # Analyse msg + assert exc_val is not None + if not exc_val.args or exc_val.args[0] != id(self): + skip = True + if not skip: + if sys.version_info >= (3, 11): + asyncio.current_task().uncancel() + raise asyncio.TimeoutError + # state is EXIT if not timed out previously + if self._state != _State.TIMEOUT: + self._state = _State.EXIT self._reject() return None def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() + # Note: the second '.cancel()' call is ignored on Python 3.11 + if sys.version_info >= (3, 9): + task.cancel(id(self)) + else: + task.cancel() self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None diff --git a/tests/test_timeout.py b/tests/test_timeout.py index d32c5fd..d62b160 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -361,3 +361,40 @@ async def test_deprecated_with() -> None: with pytest.warns(DeprecationWarning): with timeout(1): await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_double_timeouts() -> None: + with pytest.raises(asyncio.TimeoutError): + async with timeout(0.1) as cm1: + async with timeout(0.1) as cm2: + await asyncio.sleep(10) + + assert cm1.expired + assert cm2.expired + + +@pytest.mark.asyncio +async def test_timeout_with_cancelled_task() -> None: + + event = asyncio.Event() + + async def coro() -> None: + event.set() + async with timeout_cm: + await asyncio.sleep(5) + + async def main() -> str: + task = asyncio.create_task(coro()) + await event.wait() + loop = asyncio.get_running_loop() + timeout_cm.update(loop.time()) # reschedule to the next loop iteration + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + return "ok" + + timeout_cm = timeout(3600) # reschedule just before the usage + task2 = asyncio.create_task(main()) + assert "ok" == await task2 + assert timeout_cm.expired