Skip to content

Commit

Permalink
Add async strategies (#451)
Browse files Browse the repository at this point in the history
* Add async strategies

* Fix init typing

* Reuse is_coroutine_callable

* Keep only async predicate overrides and DRY implementations

* Ensure async and/or versions called when necessary

* Run ruff format

* Copy over strategies as async

* Add release note
  • Loading branch information
hasier authored Jun 12, 2024
1 parent cb15300 commit 21137e7
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 35 deletions.
5 changes: 5 additions & 0 deletions releasenotes/notes/add-async-actions-b249c527d99723bb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Added the ability to use async functions for retries. This way, you can now use
asyncio coroutines for retry strategy predicates.
28 changes: 19 additions & 9 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import warnings
from abc import ABC, abstractmethod
from concurrent import futures
from inspect import iscoroutinefunction

from . import _utils

# Import all built-in retry strategies for easier usage.
from .retry import retry_base # noqa
Expand Down Expand Up @@ -87,6 +88,7 @@
if t.TYPE_CHECKING:
import types

from . import asyncio as tasyncio
from .retry import RetryBaseT
from .stop import StopBaseT
from .wait import WaitBaseT
Expand Down Expand Up @@ -593,16 +595,24 @@ def retry(func: WrappedFn) -> WrappedFn: ...

@t.overload
def retry(
sleep: t.Callable[[t.Union[int, float]], t.Optional[t.Awaitable[None]]] = sleep,
sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep,
stop: "StopBaseT" = stop_never,
wait: "WaitBaseT" = wait_none(),
retry: "RetryBaseT" = retry_if_exception_type(),
before: t.Callable[["RetryCallState"], None] = before_nothing,
after: t.Callable[["RetryCallState"], None] = after_nothing,
before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None,
retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(),
before: t.Callable[
["RetryCallState"], t.Union[None, t.Awaitable[None]]
] = before_nothing,
after: t.Callable[
["RetryCallState"], t.Union[None, t.Awaitable[None]]
] = after_nothing,
before_sleep: t.Optional[
t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]
] = None,
reraise: bool = False,
retry_error_cls: t.Type["RetryError"] = RetryError,
retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None,
retry_error_callback: t.Optional[
t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]
] = None,
) -> t.Callable[[WrappedFn], WrappedFn]: ...


Expand All @@ -624,7 +634,7 @@ def wrap(f: WrappedFn) -> WrappedFn:
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
r: "BaseRetrying"
if iscoroutinefunction(f):
if _utils.is_coroutine_callable(f):
r = AsyncRetrying(*dargs, **dkw)
elif (
tornado
Expand All @@ -640,7 +650,7 @@ def wrap(f: WrappedFn) -> WrappedFn:
return wrap


from tenacity._asyncio import AsyncRetrying # noqa:E402,I100
from tenacity.asyncio import AsyncRetrying # noqa:E402,I100

if tornado:
from tenacity.tornadoweb import TornadoRetrying
Expand Down
12 changes: 12 additions & 0 deletions tenacity/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,15 @@ def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool:
partial_call = isinstance(call, functools.partial) and call.func
dunder_call = partial_call or getattr(call, "__call__", None)
return inspect.iscoroutinefunction(dunder_call)


def wrap_to_async_func(
call: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Awaitable[typing.Any]]:
if is_coroutine_callable(call):
return call

async def inner(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return call(*args, **kwargs)

return inner
86 changes: 64 additions & 22 deletions tenacity/_asyncio.py → tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,29 @@
import sys
import typing as t

import tenacity
from tenacity import AttemptManager
from tenacity import BaseRetrying
from tenacity import DoAttempt
from tenacity import DoSleep
from tenacity import RetryCallState
from tenacity import RetryError
from tenacity import after_nothing
from tenacity import before_nothing
from tenacity import _utils

# Import all built-in retry strategies for easier usage.
from .retry import RetryBaseT
from .retry import retry_all # noqa
from .retry import retry_any # noqa
from .retry import retry_if_exception # noqa
from .retry import retry_if_result # noqa
from ..retry import RetryBaseT as SyncRetryBaseT

if t.TYPE_CHECKING:
from tenacity.stop import StopBaseT
from tenacity.wait import WaitBaseT

WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])

Expand All @@ -38,15 +54,41 @@ def asyncio_sleep(duration: float) -> t.Awaitable[None]:


class AsyncRetrying(BaseRetrying):
sleep: t.Callable[[float], t.Awaitable[t.Any]]

def __init__(
self,
sleep: t.Callable[[float], t.Awaitable[t.Any]] = asyncio_sleep,
**kwargs: t.Any,
sleep: t.Callable[
[t.Union[int, float]], t.Union[None, t.Awaitable[None]]
] = asyncio_sleep,
stop: "StopBaseT" = tenacity.stop.stop_never,
wait: "WaitBaseT" = tenacity.wait.wait_none(),
retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(),
before: t.Callable[
["RetryCallState"], t.Union[None, t.Awaitable[None]]
] = before_nothing,
after: t.Callable[
["RetryCallState"], t.Union[None, t.Awaitable[None]]
] = after_nothing,
before_sleep: t.Optional[
t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]
] = None,
reraise: bool = False,
retry_error_cls: t.Type["RetryError"] = RetryError,
retry_error_callback: t.Optional[
t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]
] = None,
) -> None:
super().__init__(**kwargs)
self.sleep = sleep
super().__init__(
sleep=sleep, # type: ignore[arg-type]
stop=stop,
wait=wait,
retry=retry, # type: ignore[arg-type]
before=before, # type: ignore[arg-type]
after=after, # type: ignore[arg-type]
before_sleep=before_sleep, # type: ignore[arg-type]
reraise=reraise,
retry_error_cls=retry_error_cls,
retry_error_callback=retry_error_callback,
)

async def __call__( # type: ignore[override]
self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any
Expand All @@ -65,39 +107,29 @@ async def __call__( # type: ignore[override]
retry_state.set_result(result)
elif isinstance(do, DoSleep):
retry_state.prepare_for_next_attempt()
await self.sleep(do)
await self.sleep(do) # type: ignore[misc]
else:
return do # type: ignore[no-any-return]

@classmethod
def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
if _utils.is_coroutine_callable(fn):
return fn

async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any:
return fn(*args, **kwargs)

return inner

def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None:
self.iter_state.actions.append(self._wrap_action_func(fn))
self.iter_state.actions.append(_utils.wrap_to_async_func(fn))

async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)(
self.iter_state.retry_run_result = await _utils.wrap_to_async_func(self.retry)(
retry_state
)

async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
if self.wait:
sleep = await self._wrap_action_func(self.wait)(retry_state)
sleep = await _utils.wrap_to_async_func(self.wait)(retry_state)
else:
sleep = 0.0

retry_state.upcoming_sleep = sleep

async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)(
self.iter_state.stop_run_result = await _utils.wrap_to_async_func(self.stop)(
retry_state
)

Expand Down Expand Up @@ -127,7 +159,7 @@ async def __anext__(self) -> AttemptManager:
return AttemptManager(retry_state=self._retry_state)
elif isinstance(do, DoSleep):
self._retry_state.prepare_for_next_attempt()
await self.sleep(do)
await self.sleep(do) # type: ignore[misc]
else:
raise StopAsyncIteration

Expand All @@ -146,3 +178,13 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined]

return async_wrapped # type: ignore[return-value]


__all__ = [
"retry_all",
"retry_any",
"retry_if_exception",
"retry_if_result",
"WrappedFn",
"AsyncRetrying",
]
125 changes: 125 additions & 0 deletions tenacity/asyncio/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2016–2021 Julien Danjou
# Copyright 2016 Joshua Harlow
# Copyright 2013-2014 Ray Holder
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import typing

from tenacity import _utils
from tenacity import retry_base

if typing.TYPE_CHECKING:
from tenacity import RetryCallState


class async_retry_base(retry_base):
"""Abstract base class for async retry strategies."""

@abc.abstractmethod
async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
pass

def __and__( # type: ignore[override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_all":
return retry_all(self, other)

def __rand__( # type: ignore[misc,override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_all":
return retry_all(other, self)

def __or__( # type: ignore[override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_any":
return retry_any(self, other)

def __ror__( # type: ignore[misc,override]
self, other: "typing.Union[retry_base, async_retry_base]"
) -> "retry_any":
return retry_any(other, self)


RetryBaseT = typing.Union[
async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]
]


class retry_if_exception(async_retry_base):
"""Retry strategy that retries if an exception verifies a predicate."""

def __init__(
self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]
) -> None:
self.predicate = predicate

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
if retry_state.outcome is None:
raise RuntimeError("__call__() called before outcome was set")

if retry_state.outcome.failed:
exception = retry_state.outcome.exception()
if exception is None:
raise RuntimeError("outcome failed but the exception is None")
return await self.predicate(exception)
else:
return False


class retry_if_result(async_retry_base):
"""Retries if the result verifies a predicate."""

def __init__(
self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]
) -> None:
self.predicate = predicate

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
if retry_state.outcome is None:
raise RuntimeError("__call__() called before outcome was set")

if not retry_state.outcome.failed:
return await self.predicate(retry_state.outcome.result())
else:
return False


class retry_any(async_retry_base):
"""Retries if any of the retries condition is valid."""

def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None:
self.retries = retries

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
result = False
for r in self.retries:
result = result or await _utils.wrap_to_async_func(r)(retry_state)
if result:
break
return result


class retry_all(async_retry_base):
"""Retries if all the retries condition are valid."""

def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None:
self.retries = retries

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
result = True
for r in self.retries:
result = result and await _utils.wrap_to_async_func(r)(retry_state)
if not result:
break
return result
10 changes: 8 additions & 2 deletions tenacity/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ def __call__(self, retry_state: "RetryCallState") -> bool:
pass

def __and__(self, other: "retry_base") -> "retry_all":
return retry_all(self, other)
return other.__rand__(self)

def __rand__(self, other: "retry_base") -> "retry_all":
return retry_all(other, self)

def __or__(self, other: "retry_base") -> "retry_any":
return retry_any(self, other)
return other.__ror__(self)

def __ror__(self, other: "retry_base") -> "retry_any":
return retry_any(other, self)


RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], bool]]
Expand Down
Loading

1 comment on commit 21137e7

@Knamdev
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is failing at line 653 in tenacity/init.py

from tenacity.asyncio import AsyncRetrying # noqa:E402,I100
ModuleNotFoundError: No module named 'tenacity.asyncio'

Please sign in to comment.