Skip to content

Commit

Permalink
Forward wrapped args with generic callables
Browse files Browse the repository at this point in the history
Signed-off-by: kadambishreyas <shreyaskadambi@gmail.com>
  • Loading branch information
kadambishreyas committed Jan 8, 2025
1 parent 83d17a5 commit c952247
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class _WrappedFunctionType(enum.Enum):
_ASYNC_FN_TYPES = (_WrappedFunctionType.ASYNC_FUNCTION, _WrappedFunctionType.ASYNC_GENERATOR)


_P = typing.ParamSpec("_P")


_R = typing.TypeVar("_R")


def _wrapped_function_type(func: collections.abc.Callable) -> _WrappedFunctionType:
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
Expand Down Expand Up @@ -103,11 +109,11 @@ def __init__(self, exception: Exception):


def _wrap(
func: collections.abc.Callable,
func: collections.abc.Callable[_P, _R],
once_factory: _ONCE_FACTORY_TYPE,
fn_type: _WrappedFunctionType,
retry_exceptions: bool,
) -> collections.abc.Callable:
) -> collections.abc.Callable[_P, _R]:
"""Generate a wrapped function appropriate to the function type.
The once_factory lets us reuse logic for both per-thread and singleton.
Expand Down Expand Up @@ -150,7 +156,7 @@ async def wrapped(*args, **kwargs) -> typing.Any:
async with once_base.async_lock:
if not once_base.called:
try:
once_base.return_value = await func(*args, **kwargs)
once_base.return_value = await func(*args, **kwargs) # type: ignore
except Exception as exception:
if retry_exceptions:
raise exception
Expand Down Expand Up @@ -266,8 +272,8 @@ def _get_once_per_thread():


def once(
*args, per_thread=False, retry_exceptions=False, allow_reset=False
) -> collections.abc.Callable:
*args: collections.abc.Callable[_P, _R], per_thread=False, retry_exceptions=False, allow_reset=False
) -> collections.abc.Callable[_P, _R]:
"""Decorator to ensure a function is only called once.
The restriction of only one call also holds across threads. However, this
Expand Down Expand Up @@ -312,11 +318,14 @@ def once(
# This trick lets this function be a decorator directly, or be called
# to create a decorator.
# Both @once and @once() will function correctly.
return functools.partial(
once,
per_thread=per_thread,
retry_exceptions=retry_exceptions,
allow_reset=allow_reset,
return typing.cast(
collections.abc.Callable[_P, _R],
functools.partial(
once,
per_thread=per_thread,
retry_exceptions=retry_exceptions,
allow_reset=allow_reset,
)
)
if _is_method(func):
raise SyntaxError(
Expand All @@ -332,11 +341,12 @@ def once(
return _wrap(func, once_factory, fn_type, retry_exceptions)


class once_per_class: # pylint: disable=invalid-name
class once_per_class(typing.Generic[_P, _R]): # pylint: disable=invalid-name
"""A version of once for class methods which runs once across all instances."""

is_classmethod: bool
is_staticmethod: bool
func: collections.abc.Callable[_P, _R]

@classmethod
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False):
Expand All @@ -349,7 +359,7 @@ def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_re

def __init__(
self,
func: collections.abc.Callable,
func: collections.abc.Callable[_P, _R],
per_thread: bool = False,
retry_exceptions: bool = False,
allow_reset: bool = False,
Expand All @@ -363,7 +373,7 @@ def __init__(
)
self.retry_exceptions = retry_exceptions

def _inspect_function(self, func: collections.abc.Callable):
def _inspect_function(self, func: collections.abc.Callable[_P, _R]) -> collections.abc.Callable[_P, _R]:
if not _is_method(func):
raise SyntaxError(
"Attempting to use @once.once_per_class method-only decorator "
Expand All @@ -383,20 +393,22 @@ def _inspect_function(self, func: collections.abc.Callable):

# This is needed for a decorator on a class method to return a
# bound version of the function to the object or class.
def __get__(self, obj, cls) -> collections.abc.Callable:
def __get__(self, obj, cls) -> collections.abc.Callable[_P, _R]:
func = self.func
if self.is_classmethod:
func = functools.partial(self.func, cls)
elif self.is_staticmethod:
func = self.func
else:
elif not self.is_staticmethod:
func = functools.partial(self.func, obj)

# Properly annotate the return type of _wrap to match Callable[P, R].
return _wrap(func, self.once_factory, self.fn_type, self.retry_exceptions)


class once_per_instance: # pylint: disable=invalid-name
class once_per_instance(typing.Generic[_P, _R]): # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

is_property: bool
func: collections.abc.Callable[_P, _R]

@classmethod
def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_reset=False):
Expand All @@ -406,7 +418,7 @@ def with_options(cls, per_thread: bool = False, retry_exceptions=False, allow_re

def __init__(
self,
func: collections.abc.Callable,
func: collections.abc.Callable[_P, _R],
per_thread: bool = False,
retry_exceptions: bool = False,
allow_reset: bool = False,
Expand All @@ -415,9 +427,9 @@ def __init__(
self.fn_type = _wrapped_function_type(self.func)
self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
self.callables_lock = threading.Lock()
self.callables: weakref.WeakKeyDictionary[typing.Any, collections.abc.Callable] = (
weakref.WeakKeyDictionary()
)
self.callables: weakref.WeakKeyDictionary[
typing.Any, collections.abc.Callable[_P, _R]
] = weakref.WeakKeyDictionary()
self.per_thread = per_thread
self.retry_exceptions = retry_exceptions
self.allow_reset = allow_reset
Expand All @@ -431,7 +443,7 @@ def once_factory(self) -> _ONCE_FACTORY_TYPE:
self.is_async_fn, per_thread=self.per_thread, allow_reset=self.allow_reset
)

def _inspect_function(self, func: collections.abc.Callable):
def _inspect_function(self, func: collections.abc.Callable[_P, _R]) -> collections.abc.Callable[_P, _R]:
if isinstance(func, (classmethod, staticmethod)):
raise SyntaxError("Must use @once.once_per_class on classmethod and staticmethod")
if isinstance(func, property):
Expand All @@ -448,7 +460,7 @@ def _inspect_function(self, func: collections.abc.Callable):

# This is needed for a decorator on a class method to return a
# bound version of the function to the object.
def __get__(self, obj, cls) -> collections.abc.Callable:
def __get__(self, obj, cls) -> collections.abc.Callable[_P, _R]:
del cls
with self.callables_lock:
if (callable := self.callables.get(obj)) is None:
Expand All @@ -458,5 +470,5 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
)
self.callables[obj] = callable
if self.is_property:
return callable()
return callable() # type: ignore
return callable

0 comments on commit c952247

Please sign in to comment.