Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve event handler state references #2818

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class EventHandler(EventActionsMixin):
# The function to call in response to the event.
fn: Any

# The full name of the state class this event handler is attached to.
# Emtpy string means this event handler is a server side event.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Empty

state_full_name: str = ""

class Config:
"""The Pydantic config."""

Expand Down
25 changes: 20 additions & 5 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def __init_subclass__(cls, **kwargs):
events[name] = value

for name, fn in events.items():
handler = EventHandler(fn=fn)
handler = cls._create_event_handler(fn)
cls.event_handlers[name] = handler
setattr(cls, name, handler)

Expand Down Expand Up @@ -677,7 +677,7 @@ def get_full_name(cls) -> str:

@classmethod
@functools.lru_cache()
def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]:
"""Get the class substate.

Args:
Expand All @@ -689,6 +689,9 @@ def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
Raises:
ValueError: If the substate is not found.
"""
if isinstance(path, str):
path = tuple(path.split("."))

if len(path) == 0:
return cls
if path[0] == cls.get_name():
Expand Down Expand Up @@ -789,6 +792,18 @@ def _set_var(cls, prop: BaseVar):
"""
setattr(cls, prop._var_name, prop)

@classmethod
def _create_event_handler(cls, fn):
"""Create an event handler for the given function.

Args:
fn: The function to create an event handler for.

Returns:
The event handler.
"""
return EventHandler(fn=fn, state_full_name=cls.get_full_name())

@classmethod
def _create_setter(cls, prop: BaseVar):
"""Create a setter for the var.
Expand All @@ -798,7 +813,7 @@ def _create_setter(cls, prop: BaseVar):
"""
setter_name = prop.get_setter_name(include_state=False)
if setter_name not in cls.__dict__:
event_handler = EventHandler(fn=prop.get_setter())
event_handler = cls._create_event_handler(prop.get_setter())
cls.event_handlers[setter_name] = event_handler
setattr(cls, setter_name, event_handler)

Expand Down Expand Up @@ -1752,7 +1767,7 @@ async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)

Expand Down Expand Up @@ -2268,7 +2283,7 @@ async def get_state(
_, state_path = _split_substate_key(token)
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
state_cls = self.state.get_class_substate(state_path)
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
Expand Down
19 changes: 9 additions & 10 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import json
import os
import re
import sys
from typing import TYPE_CHECKING, Any, List, Union

from reflex import constants
Expand Down Expand Up @@ -470,18 +469,18 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
if len(parts) == 1:
return ("", parts[-1])

# Get the state and the function name.
state_name, name = parts[-2:]
# Get the state full name
state_full_name = handler.state_full_name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nicer, good call


# Construct the full event handler name.
try:
# Try to get the state from the module.
state = vars(sys.modules[handler.fn.__module__])[state_name]
except Exception:
# If the state isn't in the module, just return the function name.
# Get the function name
name = parts[-1]

from reflex.state import State

if state_full_name == "state" and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__))

return (state.get_full_name(), name)
return (state_full_name, name)


def format_event_handler(handler: EventHandler) -> str:
Expand Down
Loading