From d8b4e3826614b3940a0130f06a32b7b29c135fa3 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 9 Mar 2024 12:31:57 +0100 Subject: [PATCH] improve event handler state references --- reflex/event.py | 4 ++++ reflex/state.py | 25 ++++++++++++++++++++----- reflex/utils/format.py | 19 +++++++++---------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index d81e257f94..d6c57af82c 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -147,6 +147,10 @@ class EventHandler(EventActionsMixin): # The function to call in response to the event. fn: Any + # The full name of the state class this event handler is attached to. + # Emtpy string means this event handler is a server side event. + state_full_name: str = "" + class Config: """The Pydantic config.""" diff --git a/reflex/state.py b/reflex/state.py index 980b36c150..025866d9a4 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -472,7 +472,7 @@ def __init_subclass__(cls, **kwargs): events[name] = value for name, fn in events.items(): - handler = EventHandler(fn=fn) + handler = cls._create_event_handler(fn) cls.event_handlers[name] = handler setattr(cls, name, handler) @@ -677,7 +677,7 @@ def get_full_name(cls) -> str: @classmethod @functools.lru_cache() - def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]: + def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]: """Get the class substate. Args: @@ -689,6 +689,9 @@ def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]: Raises: ValueError: If the substate is not found. """ + if isinstance(path, str): + path = tuple(path.split(".")) + if len(path) == 0: return cls if path[0] == cls.get_name(): @@ -789,6 +792,18 @@ def _set_var(cls, prop: BaseVar): """ setattr(cls, prop._var_name, prop) + @classmethod + def _create_event_handler(cls, fn): + """Create an event handler for the given function. + + Args: + fn: The function to create an event handler for. + + Returns: + The event handler. + """ + return EventHandler(fn=fn, state_full_name=cls.get_full_name()) + @classmethod def _create_setter(cls, prop: BaseVar): """Create a setter for the var. @@ -798,7 +813,7 @@ def _create_setter(cls, prop: BaseVar): """ setter_name = prop.get_setter_name(include_state=False) if setter_name not in cls.__dict__: - event_handler = EventHandler(fn=prop.get_setter()) + event_handler = cls._create_event_handler(prop.get_setter()) cls.event_handlers[setter_name] = event_handler setattr(cls, setter_name, event_handler) @@ -1752,7 +1767,7 @@ async def update_vars_internal(self, vars: dict[str, Any]) -> None: """ for var, value in vars.items(): state_name, _, var_name = var.rpartition(".") - var_state_cls = State.get_class_substate(tuple(state_name.split("."))) + var_state_cls = State.get_class_substate(state_name) var_state = await self.get_state(var_state_cls) setattr(var_state, var_name, value) @@ -2268,7 +2283,7 @@ async def get_state( _, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. - state_cls = self.state.get_class_substate(tuple(state_path.split("."))) + state_cls = self.state.get_class_substate(state_path) else: raise RuntimeError( "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 53b55d0ebd..0e57ca9efc 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -6,7 +6,6 @@ import json import os import re -import sys from typing import TYPE_CHECKING, Any, List, Union from reflex import constants @@ -470,18 +469,18 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: if len(parts) == 1: return ("", parts[-1]) - # Get the state and the function name. - state_name, name = parts[-2:] + # Get the state full name + state_full_name = handler.state_full_name - # Construct the full event handler name. - try: - # Try to get the state from the module. - state = vars(sys.modules[handler.fn.__module__])[state_name] - except Exception: - # If the state isn't in the module, just return the function name. + # Get the function name + name = parts[-1] + + from reflex.state import State + + if state_full_name == "state" and name not in State.__dict__: return ("", to_snake_case(handler.fn.__qualname__)) - return (state.get_full_name(), name) + return (state_full_name, name) def format_event_handler(handler: EventHandler) -> str: