Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-2287] Avoid fetching same state from redis multiple times #4055

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,11 +2874,14 @@ class StateManagerRedis(StateManager):
# Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set()

async def _get_parent_state(self, token: str) -> BaseState | None:
async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
"""Get the parent state for the state requested in the token.

Args:
token: The token to get the state for (_substate_key).
state: The state instance to get parent state for.

Returns:
The parent state for the state requested by the token or None if there is no such parent.
Expand All @@ -2887,11 +2890,15 @@ async def _get_parent_state(self, token: str) -> BaseState | None:
client_token, state_path = _split_substate_key(token)
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
cached_substates = None
if state is not None:
cached_substates = [state]
# Retrieve the parent state to populate event handlers onto this substate.
parent_state = await self.get_state(
token=_substate_key(client_token, parent_state_name),
top_level=False,
get_substates=False,
cached_substates=cached_substates,
)
return parent_state

Expand Down Expand Up @@ -2923,6 +2930,8 @@ async def _populate_substates(
tasks = {}
# Retrieve the necessary substates from redis.
for substate_cls in fetch_substates:
if substate_cls.get_name() in state.substates:
continue
substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task(
self.get_state(
Expand All @@ -2943,6 +2952,7 @@ async def get_state(
top_level: bool = True,
get_substates: bool = True,
parent_state: BaseState | None = None,
cached_substates: list[BaseState] | None = None,
) -> BaseState:
"""Get the state for a token.

Expand All @@ -2951,6 +2961,7 @@ async def get_state(
top_level: If true, return an instance of the top-level state (self.state).
get_substates: If true, also retrieve substates.
parent_state: If provided, use this parent_state instead of getting it from redis.
cached_substates: If provided, attach these substates to the state.

Returns:
The state for the token.
Expand All @@ -2974,39 +2985,29 @@ async def get_state(
if redis_state is not None:
# Deserialize the substate.
state = BaseState._deserialize(data=redis_state)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to catch the schema mismatch error here.


# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)

# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return state._get_root_state()
return state

# TODO: dedupe the following logic with the above block
# Key didn't exist so we have to create a new instance for this token.
else:
# Key didn't exist so we have to create a new instance for this token.
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
init_substates=False,
_reflex_internal_init=True,
)
# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token)
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
parent_state=parent_state,
init_substates=False,
_reflex_internal_init=True,
)
parent_state = await self._get_parent_state(token, state)
# Set up bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates for the newly created state.
# Avoid fetching substates multiple times.
if cached_substates:
for substate in cached_substates:
state.substates[substate.get_name()] = substate
if substate.parent_state is None:
substate.parent_state = state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)

# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
Expand Down
Loading