Skip to content

Commit

Permalink
chore(typing): improve typing of WrappedFn (#390)
Browse files Browse the repository at this point in the history
This change improves the typing of WrappedFn.
It makes explictly the two signatures of tenacity.retry() with overload.

This avoids mypy thinking the return type is `<nothing>`
  • Loading branch information
sileht authored Feb 9, 2023
1 parent 78c8d4b commit b49eb37
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 47 deletions.
97 changes: 65 additions & 32 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import sys
import threading
Expand Down Expand Up @@ -91,37 +92,8 @@
from .wait import WaitBaseT


WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
_RetValT = t.TypeVar("_RetValT")


def retry(*dargs: t.Any, **dkw: t.Any) -> t.Union[WrappedFn, t.Callable[[WrappedFn], WrappedFn]]: # noqa
"""Wrap a function with a new `Retrying` object.
:param dargs: positional arguments passed to Retrying object
:param dkw: keyword arguments passed to the Retrying object
"""
# support both @retry and @retry() as valid syntax
if len(dargs) == 1 and callable(dargs[0]):
return retry()(dargs[0])
else:

def wrap(f: WrappedFn) -> WrappedFn:
if isinstance(f, retry_base):
warnings.warn(
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
if iscoroutinefunction(f):
r: "BaseRetrying" = AsyncRetrying(*dargs, **dkw)
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
r = TornadoRetrying(*dargs, **dkw)
else:
r = Retrying(*dargs, **dkw)

return r.wraps(f)

return wrap


class TryAgain(Exception):
Expand Down Expand Up @@ -382,14 +354,24 @@ def __iter__(self) -> t.Generator[AttemptManager, None, None]:
break

@abstractmethod
def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
def __call__(
self,
fn: t.Callable[..., WrappedFnReturnT],
*args: t.Any,
**kwargs: t.Any,
) -> WrappedFnReturnT:
pass


class Retrying(BaseRetrying):
"""Retrying controller."""

def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
def __call__(
self,
fn: t.Callable[..., WrappedFnReturnT],
*args: t.Any,
**kwargs: t.Any,
) -> WrappedFnReturnT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
Expand Down Expand Up @@ -510,6 +492,57 @@ def __repr__(self) -> str:
return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>"


@t.overload
def retry(func: WrappedFn) -> WrappedFn:
...


@t.overload
def retry(
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
stop: "StopBaseT" = stop_never,
wait: "WaitBaseT" = wait_none(),
retry: "RetryBaseT" = retry_if_exception_type(),
before: t.Callable[["RetryCallState"], None] = before_nothing,
after: t.Callable[["RetryCallState"], None] = after_nothing,
before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None,
reraise: bool = False,
retry_error_cls: t.Type["RetryError"] = RetryError,
retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None,
) -> t.Callable[[WrappedFn], WrappedFn]:
...


def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any:
"""Wrap a function with a new `Retrying` object.
:param dargs: positional arguments passed to Retrying object
:param dkw: keyword arguments passed to the Retrying object
"""
# support both @retry and @retry() as valid syntax
if len(dargs) == 1 and callable(dargs[0]):
return retry()(dargs[0])
else:

def wrap(f: WrappedFn) -> WrappedFn:
if isinstance(f, retry_base):
warnings.warn(
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
r: "BaseRetrying"
if iscoroutinefunction(f):
r = AsyncRetrying(*dargs, **dkw)
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
r = TornadoRetrying(*dargs, **dkw)
else:
r = Retrying(*dargs, **dkw)

return r.wraps(f)

return wrap


from tenacity._asyncio import AsyncRetrying # noqa:E402,I100

if tornado:
Expand Down
24 changes: 10 additions & 14 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import functools
import sys
import typing
import typing as t
from asyncio import sleep

from tenacity import AttemptManager
Expand All @@ -26,24 +26,20 @@
from tenacity import DoSleep
from tenacity import RetryCallState


WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable[..., typing.Any])
_RetValT = typing.TypeVar("_RetValT")
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])


class AsyncRetrying(BaseRetrying):
def __init__(
self, sleep: typing.Callable[[float], typing.Awaitable[typing.Any]] = sleep, **kwargs: typing.Any
) -> None:
sleep: t.Callable[[float], t.Awaitable[t.Any]]

def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None:
super().__init__(**kwargs)
self.sleep = sleep

async def __call__( # type: ignore[override]
self,
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
*args: typing.Any,
**kwargs: typing.Any,
) -> _RetValT:
self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any
) -> WrappedFnReturnT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
Expand All @@ -62,7 +58,7 @@ async def __call__( # type: ignore[override]
else:
return do # type: ignore[no-any-return]

def __iter__(self) -> typing.Generator[AttemptManager, None, None]:
def __iter__(self) -> t.Generator[AttemptManager, None, None]:
raise TypeError("AsyncRetrying object is not iterable")

def __aiter__(self) -> "AsyncRetrying":
Expand All @@ -88,7 +84,7 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(fn)
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
return await fn(*args, **kwargs)

# Preserve attributes
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ commands =

[testenv:mypy]
deps =
mypy
mypy>=1.0.0
commands =
mypy tenacity

Expand Down

0 comments on commit b49eb37

Please sign in to comment.