Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix annotated EventHandler #3076

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,11 @@ def get_event_triggers(self) -> Dict[str, Any]:
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.type_, EventHandler):
default_triggers[field.name] = getattr(
field.type_, "args_spec", lambda: []
)
args_spec = None
annotation = field.annotation
if hasattr(annotation, "__metadata__"):
args_spec = annotation.__metadata__[0]
default_triggers[field.name] = args_spec or (lambda: [])
return default_triggers

def __repr__(self) -> str:
Expand Down
13 changes: 7 additions & 6 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Optional,
Tuple,
Union,
_GenericAlias, # type: ignore
get_type_hints,
)

Expand All @@ -23,6 +22,11 @@
from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


class Event(Base):
"""An event that describes any state change in the app."""
Expand Down Expand Up @@ -118,7 +122,7 @@ class Config:
frozen = True

@classmethod
def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
def __class_getitem__(cls, args_spec: str) -> Annotated:
"""Get a typed EventHandler.

Args:
Expand All @@ -127,10 +131,7 @@ def __class_getitem__(cls, args_spec: str) -> _GenericAlias:
Returns:
The EventHandler class item.
"""
gen = _GenericAlias(cls, Any)
# Cannot subclass special typing classes, so we need to set the args_spec dynamically as an attribute.
gen.args_spec = args_spec
return gen
return Annotated[cls, args_spec]

@property
def is_background(self) -> bool:
Expand Down
23 changes: 15 additions & 8 deletions tests/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
custom_component,
)
from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler
from reflex.event import EventChain, EventHandler, parse_args_spec
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
Expand Down Expand Up @@ -1357,11 +1357,12 @@ def get_event_triggers(self) -> Dict[str, Any]:
"""
return {
**super().get_event_triggers(),
"on_a": lambda e: [e],
"on_b": lambda e: [e.target.value],
"on_c": lambda e: [],
"on_a": lambda e0: [e0],
"on_b": lambda e0: [e0.target.value],
"on_c": lambda e0: [],
"on_d": lambda: [],
"on_e": lambda: [],
"on_f": lambda a, b, c: [c, b, a],
}

class TestComponent(Component):
Expand All @@ -1370,10 +1371,16 @@ class TestComponent(Component):
on_c: EventHandler[lambda e0: []]
on_d: EventHandler[lambda: []]
on_e: EventHandler
on_f: EventHandler[lambda a, b, c: [c, b, a]]

custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
assert (
custom_component.get_event_triggers().keys()
== test_component.get_event_triggers().keys()
)
custom_triggers = custom_component.get_event_triggers()
test_triggers = test_component.get_event_triggers()
assert custom_triggers.keys() == test_triggers.keys()
for trigger_name in custom_component.get_event_triggers():
for v1, v2 in zip(
parse_args_spec(test_triggers[trigger_name]),
parse_args_spec(custom_triggers[trigger_name]),
):
assert v1.equals(v2)
Loading