diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d4648a4..b0e3c02 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,7 +27,7 @@ jobs: - python: "3.11" tox: py311 - python: "3.12" - tox: py312 + tox: py312,py312-trio - python: "3.12" tox: pep8 - python: "3.11" diff --git a/doc/source/index.rst b/doc/source/index.rst index 3f0764a..65dd208 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -568,28 +568,34 @@ in retry strategies like ``retry_if_result``. This can be done accessing the Async and retry ~~~~~~~~~~~~~~~ -Finally, ``retry`` works also on asyncio and Tornado (>= 4.5) coroutines. +Finally, ``retry`` works also on asyncio, Trio, and Tornado (>= 4.5) coroutines. Sleeps are done asynchronously too. .. code-block:: python @retry - async def my_async_function(loop): + async def my_asyncio_function(loop): await loop.getaddrinfo('8.8.8.8', 53) +.. code-block:: python + + @retry + async def my_async_trio_function(): + await trio.socket.getaddrinfo('8.8.8.8', 53) + .. code-block:: python @retry @tornado.gen.coroutine - def my_async_function(http_client, url): + def my_async_tornado_function(http_client, url): yield http_client.fetch(url) -You can even use alternative event loops such as `curio` or `Trio` by passing the correct sleep function: +You can even use alternative event loops such as `curio` by passing the correct sleep function: .. code-block:: python - @retry(sleep=trio.sleep) - async def my_async_function(loop): + @retry(sleep=curio.sleep) + async def my_async_curio_function(): await asks.get('https://example.org') Contribute diff --git a/releasenotes/notes/trio-support-retry-22bd544800cd1f36.yaml b/releasenotes/notes/trio-support-retry-22bd544800cd1f36.yaml new file mode 100644 index 0000000..b8e0c14 --- /dev/null +++ b/releasenotes/notes/trio-support-retry-22bd544800cd1f36.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + If you're using `Trio `__, then + ``@retry`` now works automatically. It's no longer necessary to + pass ``sleep=trio.sleep``. diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 3ec0088..6d63ebc 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -46,11 +46,22 @@ WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) -def asyncio_sleep(duration: float) -> t.Awaitable[None]: +def _portable_async_sleep(seconds: float) -> t.Awaitable[None]: + # If trio is already imported, then importing it is cheap. + # If trio isn't already imported, then it's definitely not running, so we + # can skip further checks. + if "trio" in sys.modules: + # If trio is available, then sniffio is too + import trio + import sniffio + + if sniffio.current_async_library() == "trio": + return trio.sleep(seconds) + # Otherwise, assume asyncio # Lazy import asyncio as it's expensive (responsible for 25-50% of total import overhead). import asyncio - return asyncio.sleep(duration) + return asyncio.sleep(seconds) class AsyncRetrying(BaseRetrying): @@ -58,7 +69,7 @@ def __init__( self, sleep: t.Callable[ [t.Union[int, float]], t.Union[None, t.Awaitable[None]] - ] = asyncio_sleep, + ] = _portable_async_sleep, stop: "StopBaseT" = tenacity.stop.stop_never, wait: "WaitBaseT" = tenacity.wait.wait_none(), retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(), diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 48f6286..8716529 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -18,6 +18,13 @@ import unittest from functools import wraps +try: + import trio +except ImportError: + have_trio = False +else: + have_trio = True + import pytest import tenacity @@ -55,7 +62,7 @@ async def _retryable_coroutine_with_2_attempts(thing): thing.go() -class TestAsync(unittest.TestCase): +class TestAsyncio(unittest.TestCase): @asynctest async def test_retry(self): thing = NoIOErrorAfterCount(5) @@ -138,6 +145,21 @@ def after(retry_state): assert list(attempt_nos2) == [1, 2, 3] +@unittest.skipIf(not have_trio, "trio not installed") +class TestTrio(unittest.TestCase): + def test_trio_basic(self): + thing = NoIOErrorAfterCount(5) + + @retry + async def trio_function(): + await trio.sleep(0.00001) + return thing.go() + + trio.run(trio_function) + + assert thing.counter == thing.count + + class TestContextManager(unittest.TestCase): @asynctest async def test_do_max_attempts(self): diff --git a/tox.ini b/tox.ini index 13e5a1d..14f8ae0 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py3{8,9,10,11,12}, pep8, pypy3 +envlist = py3{8,9,10,11,12,12-trio}, pep8, pypy3 skip_missing_interpreters = True [testenv] @@ -8,6 +8,7 @@ sitepackages = False deps = .[test] .[doc] + trio: trio commands = py3{8,9,10,11,12},pypy3: pytest {posargs} py3{8,9,10,11,12},pypy3: sphinx-build -a -E -W -b doctest doc/source doc/build @@ -24,10 +25,11 @@ commands = deps = mypy>=1.0.0 pytest # for stubs + trio commands = mypy {posargs} [testenv:reno] basepython = python3 deps = reno -commands = reno {posargs} \ No newline at end of file +commands = reno {posargs}