Skip to content

Commit

Permalink
fix: correct type annotations for decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
MultifokalHirn committed Dec 21, 2023
1 parent 0009412 commit 6942bbb
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/img/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions src/ornaments/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import ParamSpec, TypeVar

# P = TypeVar("P") # Parameter type
P = ParamSpec("P")
R = TypeVar("R", covariant=True) # Return type
3 changes: 2 additions & 1 deletion src/ornaments/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# pragma: no cover
import os


def print_logo(terminal_width: int = 80) -> None:
def print_logo(terminal_width: int = 80) -> None: # pragma: no cover
"""
Prints the logo of the project.
Expand Down
11 changes: 6 additions & 5 deletions src/ornaments/invariants/only_called_once.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any
from collections.abc import Callable

from .._types import P, R
from ..exceptions import CalledTooOftenError, CalledTooOftenWarning
from ..scopes import CLASS_SCOPE, OBJECT_SCOPE, SESSION_SCOPE


def only_called_once(scope: str = "object", enforce: bool = False) -> Callable[..., Any]:
def only_called_once(scope: str = "object", enforce: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator that ensures a function is only called once in a given scope.
Expand All @@ -23,17 +24,17 @@ def only_callable_once() -> str:
```
"""

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@wraps(wrapped=func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if scope in OBJECT_SCOPE:
# Use the id of the instance for object scope
call_scope = (id(args[0]), func)
elif scope in CLASS_SCOPE:
# Use the class as identifier for session scope
# Note: only really useful if the class itself is instantiated more than once
# - otherwise, this will behave just as session scope
call_scope = (args[0].__class__, func)
call_scope = (id(args[0].__class__), func)
elif scope in SESSION_SCOPE: # session scope
# Use the function itself as identifier for session scope
call_scope = (id(func), func)
Expand Down
2 changes: 1 addition & 1 deletion src/ornaments/runtime_checks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .return_type_validation import checked_return_type
from .checked_return_type import checked_return_type

__all__ = [
"checked_return_type",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any
from collections.abc import Callable

from .._types import P, R
from ..exceptions import InvalidReturnTypeError, InvalidReturnTypeWarning


def checked_return_type(enforce: bool = False) -> Callable[..., Any]:
def checked_return_type(enforce: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Checks that the return value of the function is of a specified type. If not, it raises an exception or raises a warning.
Expand All @@ -18,13 +18,13 @@ def my_function() -> int:
"""

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
expected_type = func.__annotations__.get("return")
if expected_type is None:
raise ValueError("Expected type must be specified in the function annotations.")

@wraps(wrapped=func)
def wrapper(*args, **kwargs) -> Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
result = func(*args, **kwargs)

if not isinstance(result, expected_type):
Expand Down

0 comments on commit 6942bbb

Please sign in to comment.