Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add additional typing for calling events #4218

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from reflex.constants.compiler import SpecialAttributes
from reflex.event import (
EventCallback,
EventChain,
EventChainVar,
EventHandler,
Expand Down Expand Up @@ -1126,6 +1127,8 @@ def _event_trigger_values_use_state(self) -> bool:
for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain):
for event in trigger.events:
if isinstance(event, EventCallback):
continue
if isinstance(event, EventSpec):
if event.handler.state_full_name:
return True
Expand Down
114 changes: 72 additions & 42 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -389,7 +390,9 @@ def __call__(self, *args, **kwargs) -> EventSpec:
class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order."""

events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
events: Sequence[Union[EventSpec, EventVar, EventCallback]] = dataclasses.field(
default_factory=list
)

args_spec: Optional[Callable] = dataclasses.field(default=None)

Expand Down Expand Up @@ -1445,13 +1448,8 @@ def create(
)


G = ParamSpec("G")

IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]]

EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]

P = ParamSpec("P")
Q = ParamSpec("Q")
T = TypeVar("T")
V = TypeVar("V")
V2 = TypeVar("V2")
Expand All @@ -1473,55 +1471,73 @@ def __init__(self, func: Callable[Concatenate[Any, P], T]):
"""
self.func = func

@property
def prevent_default(self):
"""Prevent default behavior.

Returns:
The event callback with prevent default behavior.
"""
return self

@property
def stop_propagation(self):
"""Stop event propagation.

Returns:
The event callback with stop propagation behavior.
"""
return self

@overload
def __get__(
self: EventCallback[[V], T], instance: None, owner
) -> Callable[[Union[Var[V], V]], EventSpec]: ...
def __call__(
self: EventCallback[Concatenate[V, Q], T], value: V | Var[V]
) -> EventCallback[Q, T]: ...

@overload
def __get__(
self: EventCallback[[V, V2], T], instance: None, owner
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ...
def __call__(
self: EventCallback[Concatenate[V, V2, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
) -> EventCallback[Q, T]: ...

@overload
def __get__(
self: EventCallback[[V, V2, V3], T], instance: None, owner
) -> Callable[
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
EventSpec,
]: ...
def __call__(
self: EventCallback[Concatenate[V, V2, V3, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
) -> EventCallback[Q, T]: ...

@overload
def __get__(
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
],
EventSpec,
]: ...
def __call__(
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
value4: V4 | Var[V4],
) -> EventCallback[Q, T]: ...

def __call__(self, *values) -> EventCallback: # type: ignore
"""Call the function with the values.

Args:
*values: The values to call the function with.

Returns:
The function with the values.
"""
return self.func(*values) # type: ignore

@overload
def __get__(
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
Union[Var[V5], V5],
],
EventSpec,
]: ...
self: EventCallback[P, T], instance: None, owner
) -> EventCallback[P, T]: ...

@overload
def __get__(self, instance, owner) -> Callable[P, T]: ...

def __get__(self, instance, owner) -> Callable:
def __get__(self, instance, owner) -> Callable: # type: ignore
"""Get the function with the instance bound to it.

Args:
Expand All @@ -1548,6 +1564,9 @@ def event_handler(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]
return func # type: ignore
else:

class EventCallback(Generic[P, T]):
"""A descriptor that wraps a function to be used as an event."""

def event_handler(func: Callable[P, T]) -> Callable[P, T]:
"""Wrap a function to be used as an event.

Expand All @@ -1560,6 +1579,17 @@ def event_handler(func: Callable[P, T]) -> Callable[P, T]:
return func


G = ParamSpec("G")

IndividualEventType = Union[
EventSpec, EventHandler, Callable[G, Any], EventCallback[G, Any], Var[Any]
]

ItemOrList = Union[V, List[V]]

EventType = ItemOrList[IndividualEventType[G]]


class EventNamespace(types.SimpleNamespace):
"""A namespace for event related classes."""

Expand Down
Loading