diff --git a/reflex/state.py b/reflex/state.py index 1b62d5c566..7cd789fa05 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -106,11 +106,7 @@ def __init__(self, *args, parent_state: Optional[State] = None, **kwargs): for substate in self.get_substates(): self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. - for name, event_handler in self.event_handlers.items(): - fn = functools.partial(event_handler.fn, self) - fn.__module__ = event_handler.fn.__module__ # type: ignore - fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore - setattr(self, name, fn) + self._init_event_handlers() # Initialize computed vars dependencies. inherited_vars = set(self.inherited_vars).union( @@ -155,6 +151,29 @@ def _init_mutable_fields(self): self._clean() + def _init_event_handlers(self, state: State | None = None): + """Initialize event handlers. + + Allow event handlers to be called directly on the instance. This is + called recursively for all parent states. + + Args: + state: The state to initialize the event handlers on. + """ + if state is None: + state = self + + # Convert the event handlers to functions. + for name, event_handler in state.event_handlers.items(): + fn = functools.partial(event_handler.fn, self) + fn.__module__ = event_handler.fn.__module__ # type: ignore + fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore + setattr(self, name, fn) + + # Also allow direct calling of parent state event handlers + if state.parent_state is not None: + self._init_event_handlers(state.parent_state) + def _reassign_field(self, field_name: str): """Reassign the given field. diff --git a/tests/test_state.py b/tests/test_state.py index e13da9dac7..f5650fc0f0 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -992,10 +992,18 @@ def set_v(self, v: int): def set_v2(self, v: int): self.set_v(v) + class SubState(MainState): + def set_v3(self, v: int): + self.set_v2(v) + ms = MainState() ms.set_v2(1) assert ms.v == 1 + # ensure handler can be called from substate + ms.substates[SubState.get_name()].set_v3(2) + assert ms.v == 2 + def test_computed_var_cached(): """Test that a ComputedVar doesn't recalculate when accessed."""