|
1 | 1 | import asyncio |
| 2 | +import sys |
2 | 3 | import threading |
3 | 4 | from types import TracebackType |
4 | 5 | from typing import ( |
|
29 | 30 | anyio = None # type: ignore |
30 | 31 |
|
31 | 32 |
|
| 33 | +if sys.version_info >= (3, 11): # pragma: nocover |
| 34 | + import asyncio as asyncio_timeout |
| 35 | + |
| 36 | + anyio_shield = None |
| 37 | +else: # pragma: nocover |
| 38 | + import async_timeout as asyncio_timeout |
| 39 | + |
| 40 | + if anyio is None: # pragma: nocover |
| 41 | + raise RuntimeError("Running in Python<3.11 requires anyio") |
| 42 | + anyio_shield = anyio.CancelScope |
| 43 | + |
| 44 | + |
32 | 45 | AsyncBackend = Literal["asyncio", "trio"] |
33 | 46 |
|
34 | 47 |
|
@@ -163,9 +176,11 @@ async def wait(self, timeout: Optional[float] = None) -> None: |
163 | 176 | with trio.fail_after(timeout_or_inf): |
164 | 177 | await event.wait() |
165 | 178 | else: |
166 | | - asyncio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} |
| 179 | + asyncio_exc_map: ExceptionMapping = { |
| 180 | + asyncio.exceptions.TimeoutError: PoolTimeout |
| 181 | + } |
167 | 182 | with map_exceptions(asyncio_exc_map): |
168 | | - async with asyncio.timeout(timeout): |
| 183 | + async with asyncio_timeout.timeout(timeout): |
169 | 184 | await event.wait() |
170 | 185 |
|
171 | 186 |
|
@@ -217,17 +232,20 @@ async def shield(shielded: Callable[[], Coroutine[Any, Any, None]]) -> None: |
217 | 232 | if current_async_backend() == "trio": |
218 | 233 | with trio.CancelScope(shield=True): |
219 | 234 | await shielded() |
| 235 | + elif sys.version_info < (3, 11): # pragma: nocover |
| 236 | + with anyio_shield(shield=True): |
| 237 | + await shielded() |
220 | 238 | else: |
221 | | - await AsyncShieldCancellation._asyncio_shield(shielded) |
| 239 | + await AsyncShieldCancellation._asyncio_shield(shielded) # pragma: nocover |
222 | 240 |
|
223 | 241 | @staticmethod |
224 | 242 | async def _asyncio_shield( |
225 | 243 | shielded: Callable[[], Coroutine[Any, Any, None]], |
226 | | - ) -> None: |
| 244 | + ) -> None: # pragma: nocover |
227 | 245 | inner_task = asyncio.create_task(shielded()) |
228 | 246 | try: |
229 | 247 | await asyncio.shield(inner_task) |
230 | | - except asyncio.CancelledError: |
| 248 | + except (asyncio.exceptions.CancelledError, asyncio.CancelledError): |
231 | 249 | # Let the inner_task to complete as it was shielded from the cancellation |
232 | 250 | await inner_task |
233 | 251 |
|
|
0 commit comments