Skip to content

Commit

Permalink
Annotate decorators that wrap Document methods (BeanieODM#679)
Browse files Browse the repository at this point in the history
Type checkers were complaining about missing `self`
argument in decorated `Document` methods. This was
caused by incomplete annotations of used decorators.
  • Loading branch information
Maxim Borisov committed Feb 29, 2024
1 parent e2d95be commit 7f8f92a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 31 deletions.
55 changes: 40 additions & 15 deletions beanie/odm/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from beanie.odm.documents import Document
from beanie.odm.documents import AsyncDocMethod, DocType, Document

P = ParamSpec("P")
R = TypeVar("R")


class EventTypes(str, Enum):
Expand Down Expand Up @@ -136,10 +142,14 @@ async def run_actions(
await asyncio.gather(*coros)


# `Any` because there is arbitrary attribute assignment on this type
F = TypeVar("F", bound=Any)


def register_action(
event_types: Tuple[Union[List[EventTypes], EventTypes]],
event_types: Tuple[Union[List[EventTypes], EventTypes], ...],
action_direction: ActionDirections,
):
) -> Callable[[F], F]:
"""
Decorator. Base registration method.
Used inside `before_event` and `after_event`
Expand All @@ -154,7 +164,7 @@ def register_action(
else:
final_event_types.append(event_type)

def decorator(f):
def decorator(f: F) -> F:
f.has_action = True
f.event_types = final_event_types
f.action_direction = action_direction
Expand All @@ -163,7 +173,9 @@ def decorator(f):
return decorator


def before_event(*args: Union[List[EventTypes], EventTypes]):
def before_event(
*args: Union[List[EventTypes], EventTypes]
) -> Callable[[F], F]:
"""
Decorator. It adds action, which should run before mentioned one
or many events happen
Expand All @@ -172,11 +184,13 @@ def before_event(*args: Union[List[EventTypes], EventTypes]):
:return: None
"""
return register_action(
action_direction=ActionDirections.BEFORE, event_types=args # type: ignore
action_direction=ActionDirections.BEFORE, event_types=args
)


def after_event(*args: Union[List[EventTypes], EventTypes]):
def after_event(
*args: Union[List[EventTypes], EventTypes]
) -> Callable[[F], F]:
"""
Decorator. It adds action, which should run after mentioned one
or many events happen
Expand All @@ -186,26 +200,32 @@ def after_event(*args: Union[List[EventTypes], EventTypes]):
"""

return register_action(
action_direction=ActionDirections.AFTER, event_types=args # type: ignore
action_direction=ActionDirections.AFTER, event_types=args
)


def wrap_with_actions(event_type: EventTypes):
def wrap_with_actions(
event_type: EventTypes,
) -> Callable[
["AsyncDocMethod[DocType, P, R]"], "AsyncDocMethod[DocType, P, R]"
]:
"""
Helper function to wrap Document methods with
before and after event listeners
:param event_type: EventTypes - event types
:return: None
"""

def decorator(f: Callable):
def decorator(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(
self,
*args,
self: "Document",
*args: P.args,
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
**kwargs,
):
**kwargs: P.kwargs,
) -> R:
if skip_actions is None:
skip_actions = []

Expand All @@ -216,7 +236,12 @@ async def wrapper(
exclude=skip_actions,
)

result = await f(self, *args, skip_actions=skip_actions, **kwargs)
result = await f(
self,
*args,
skip_actions=skip_actions, # type: ignore[arg-type]
**kwargs,
)

await ActionRegistry.run_actions(
self,
Expand Down
19 changes: 13 additions & 6 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from enum import Enum
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Dict,
Iterable,
Expand Down Expand Up @@ -32,6 +34,7 @@
DeleteResult,
InsertManyResult,
)
from typing_extensions import Concatenate, ParamSpec, TypeAlias

from beanie.exceptions import (
CollectionWasNotInitialized,
Expand Down Expand Up @@ -104,6 +107,10 @@
from pydantic import model_validator

DocType = TypeVar("DocType", bound="Document")
P = ParamSpec("P")
R = TypeVar("R")
SyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
AsyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], Awaitable[R]]
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


Expand Down Expand Up @@ -529,7 +536,7 @@ async def save(
link_rule: WriteRules = WriteRules.DO_NOTHING,
ignore_revision: bool = False,
**kwargs,
) -> None:
) -> DocType:
"""
Update an existing model in the database or
insert it if it does not yet exist.
Expand Down Expand Up @@ -605,12 +612,12 @@ async def save(
@wrap_with_actions(EventTypes.SAVE_CHANGES)
@validate_self_before
async def save_changes(
self,
self: DocType,
ignore_revision: bool = False,
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
) -> None:
) -> Optional[DocType]:
"""
Save changes.
State management usage must be turned on
Expand All @@ -632,7 +639,7 @@ async def save_changes(
)
else:
return await self.set(
changes, # type: ignore #TODO fix typing
changes,
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
Expand Down Expand Up @@ -741,13 +748,13 @@ def update_all(
)

def set(
self,
self: DocType,
expression: Dict[Union[ExpressionField, str], Any],
session: Optional[ClientSession] = None,
bulk_writer: Optional[BulkWriter] = None,
skip_sync: Optional[bool] = None,
**kwargs,
):
) -> Awaitable[DocType]:
"""
Set values
Expand Down
15 changes: 11 additions & 4 deletions beanie/odm/utils/self_validation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from functools import wraps
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, TypeVar

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from beanie.odm.documents import DocType
from beanie.odm.documents import AsyncDocMethod, DocType

P = ParamSpec("P")
R = TypeVar("R")


def validate_self_before(f: Callable):
def validate_self_before(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
await self.validate_self(*args, **kwargs)
return await f(self, *args, **kwargs)

Expand Down
47 changes: 41 additions & 6 deletions beanie/odm/utils/state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, TypeVar, overload

from typing_extensions import ParamSpec

from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved

if TYPE_CHECKING:
from beanie.odm.documents import DocType
from beanie.odm.documents import AsyncDocMethod, DocType, SyncDocMethod

P = ParamSpec("P")
R = TypeVar("R")


def check_if_state_saved(self: "DocType"):
Expand All @@ -17,7 +22,21 @@ def check_if_state_saved(self: "DocType"):
raise StateNotSaved("No state was saved")


def saved_state_needed(f: Callable):
@overload
def saved_state_needed(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
...


@overload
def saved_state_needed(
f: "SyncDocMethod[DocType, P, R]",
) -> "SyncDocMethod[DocType, P, R]":
...


def saved_state_needed(f: Callable) -> Callable:
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_state_saved(self)
Expand All @@ -44,7 +63,21 @@ def check_if_previous_state_saved(self: "DocType"):
)


def previous_saved_state_needed(f: Callable):
@overload
def previous_saved_state_needed(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
...


@overload
def previous_saved_state_needed(
f: "SyncDocMethod[DocType, P, R]",
) -> "SyncDocMethod[DocType, P, R]":
...


def previous_saved_state_needed(f: Callable) -> Callable:
@wraps(f)
def sync_wrapper(self: "DocType", *args, **kwargs):
check_if_previous_state_saved(self)
Expand All @@ -60,9 +93,11 @@ async def async_wrapper(self: "DocType", *args, **kwargs):
return sync_wrapper


def save_state_after(f: Callable):
def save_state_after(
f: "AsyncDocMethod[DocType, P, R]",
) -> "AsyncDocMethod[DocType, P, R]":
@wraps(f)
async def wrapper(self: "DocType", *args, **kwargs):
async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R:
result = await f(self, *args, **kwargs)
self._save_state()
return result
Expand Down

0 comments on commit 7f8f92a

Please sign in to comment.