Skip to content

Commit

Permalink
state: _init_event_handlers recursively (#1640)
Browse files Browse the repository at this point in the history
  • Loading branch information
masenf authored Aug 25, 2023
1 parent dbaa6a1 commit 12e516d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
29 changes: 24 additions & 5 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,7 @@ def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
for name, event_handler in self.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
fn.__module__ = event_handler.fn.__module__ # type: ignore
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)
self._init_event_handlers()

# Initialize computed vars dependencies.
inherited_vars = set(self.inherited_vars).union(
Expand Down Expand Up @@ -155,6 +151,29 @@ def _init_mutable_fields(self):

self._clean()

def _init_event_handlers(self, state: State | None = None):
"""Initialize event handlers.
Allow event handlers to be called directly on the instance. This is
called recursively for all parent states.
Args:
state: The state to initialize the event handlers on.
"""
if state is None:
state = self

# Convert the event handlers to functions.
for name, event_handler in state.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
fn.__module__ = event_handler.fn.__module__ # type: ignore
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)

# Also allow direct calling of parent state event handlers
if state.parent_state is not None:
self._init_event_handlers(state.parent_state)

def _reassign_field(self, field_name: str):
"""Reassign the given field.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,10 +992,18 @@ def set_v(self, v: int):
def set_v2(self, v: int):
self.set_v(v)

class SubState(MainState):
def set_v3(self, v: int):
self.set_v2(v)

ms = MainState()
ms.set_v2(1)
assert ms.v == 1

# ensure handler can be called from substate
ms.substates[SubState.get_name()].set_v3(2)
assert ms.v == 2


def test_computed_var_cached():
"""Test that a ComputedVar doesn't recalculate when accessed."""
Expand Down

0 comments on commit 12e516d

Please sign in to comment.