Skip to content

Commit

Permalink
fixup! Annotate decorators that wrap Document methods (BeanieODM#679)
Browse files Browse the repository at this point in the history
Removed sync/async overload in favour of ignoring errors in wrappers
because mypy confused them and always expected async function.
  • Loading branch information
Maxim Borisov committed Feb 29, 2024
1 parent 7f8f92a commit 0ac26c9
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 37 deletions.
2 changes: 1 addition & 1 deletion beanie/odm/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def decorator(
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(
self: "Document",
self: "DocType",
*args: P.args,
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
**kwargs: P.kwargs,
Expand Down
9 changes: 7 additions & 2 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Awaitable,
Callable,
ClassVar,
Coroutine,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -109,8 +110,12 @@
DocType = TypeVar("DocType", bound="Document")
P = ParamSpec("P")
R = TypeVar("R")
SyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
AsyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], Awaitable[R]]
# can describe both sync and async, where R itself is a coroutine
AnyDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
# describes only async
AsyncDocMethod: TypeAlias = Callable[
Concatenate[DocType, P], Coroutine[Any, Any, R]
]
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


Expand Down
52 changes: 18 additions & 34 deletions beanie/odm/utils/state.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Callable, TypeVar, overload
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import ParamSpec

from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved

if TYPE_CHECKING:
from beanie.odm.documents import AsyncDocMethod, DocType, SyncDocMethod
from beanie.odm.documents import AnyDocMethod, AsyncDocMethod, DocType

P = ParamSpec("P")
R = TypeVar("R")
Expand All @@ -22,21 +22,9 @@ def check_if_state_saved(self: "DocType"):
raise StateNotSaved("No state was saved")


@overload
def saved_state_needed(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
...


@overload
def saved_state_needed(
f: "SyncDocMethod[DocType, P, R]",
) -> "SyncDocMethod[DocType, P, R]":
...


def saved_state_needed(f: Callable) -> Callable:
f: "AnyDocMethod[DocType, P, R]",
) -> "AnyDocMethod[DocType, P, R]":
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
Expand All @@ -45,10 +33,14 @@ def sync_wrapper(self: "DocType", *args, **kwargs):
@wraps(f)
async def async_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
return await f(self, *args, **kwargs)
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return await f(self, *args, **kwargs) # type: ignore[misc]

if inspect.iscoroutinefunction(f):
return async_wrapper
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return async_wrapper # type: ignore[return-value]
return sync_wrapper


Expand All @@ -63,21 +55,9 @@ def check_if_previous_state_saved(self: "DocType"):
)


@overload
def previous_saved_state_needed(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
...


@overload
def previous_saved_state_needed(
f: "SyncDocMethod[DocType, P, R]",
) -> "SyncDocMethod[DocType, P, R]":
...


def previous_saved_state_needed(f: Callable) -> Callable:
f: "AnyDocMethod[DocType, P, R]",
) -> "AnyDocMethod[DocType, P, R]":
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_previous_state_saved(self)
Expand All @@ -86,10 +66,14 @@ def sync_wrapper(self: "DocType", *args, **kwargs):
@wraps(f)
async def async_wrapper(self: "DocType", *args, **kwargs):
check_if_previous_state_saved(self)
return await f(self, *args, **kwargs)
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return await f(self, *args, **kwargs) # type: ignore[misc]

if inspect.iscoroutinefunction(f):
return async_wrapper
# type ignore because there is no nice/proper way to annotate both sync
# and async case without parametrized TypeVar, which is not supported
return async_wrapper # type: ignore[return-value]
return sync_wrapper


Expand Down
87 changes: 87 additions & 0 deletions tests/typing/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import (
Any,
Callable,
Coroutine,
Protocol,
TypeAlias,
assert_type,
)

from beanie import Document
from beanie.odm.actions import EventTypes, wrap_with_actions
from beanie.odm.utils.self_validation import validate_self_before
from beanie.odm.utils.state import (
previous_saved_state_needed,
save_state_after,
saved_state_needed,
)


def sync_func(doc_self: Document, arg1: str, arg2: int, /) -> Document:
"""
Models `Document` sync method that expects self
"""
raise NotImplementedError


SyncFunc: TypeAlias = Callable[[Document, str, int], Document]


async def async_func(doc_self: Document, arg1: str, arg2: int, /) -> Document:
"""
Models `Document` async method that expects self
"""
raise NotImplementedError


AsyncFunc: TypeAlias = Callable[
[Document, str, int], Coroutine[Any, Any, Document]
]


def test_wrap_with_actions_preserves_signature() -> None:
assert_type(async_func, AsyncFunc)
assert_type(wrap_with_actions(EventTypes.SAVE)(async_func), AsyncFunc)


def test_save_state_after_preserves_signature() -> None:
assert_type(async_func, AsyncFunc)
assert_type(save_state_after(async_func), AsyncFunc)


def test_validate_self_before_preserves_signature() -> None:
assert_type(async_func, AsyncFunc)
assert_type(validate_self_before(async_func), AsyncFunc)


def test_saved_state_needed_preserves_signature() -> None:
assert_type(async_func, AsyncFunc)
assert_type(saved_state_needed(async_func), AsyncFunc)

assert_type(sync_func, SyncFunc)
assert_type(saved_state_needed(sync_func), SyncFunc)


def test_previous_saved_state_needed_preserves_signature() -> None:
assert_type(async_func, AsyncFunc)
assert_type(previous_saved_state_needed(async_func), AsyncFunc)

assert_type(sync_func, SyncFunc)
assert_type(previous_saved_state_needed(sync_func), SyncFunc)


class ExpectsDocumentSelf(Protocol):
def __call__(self, doc_self: Document, /) -> Any:
...


def test_document_insert_expects_self() -> None:
test_insert: ExpectsDocumentSelf = Document.insert # noqa: F841


def test_document_save_expects_self() -> None:
test_insert: ExpectsDocumentSelf = Document.save # noqa: F841


def test_document_replace_expects_self() -> None:
test_insert: ExpectsDocumentSelf = Document.replace # noqa: F841

0 comments on commit 0ac26c9

Please sign in to comment.