Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-2287] Avoid fetching same state from redis multiple times #4055

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,11 +2874,14 @@ class StateManagerRedis(StateManager):
# Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set()

async def _get_parent_state(self, token: str) -> BaseState | None:
async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
"""Get the parent state for the state requested in the token.

Args:
token: The token to get the state for (_substate_key).
state: The state instance to get parent state for.

Returns:
The parent state for the state requested by the token or None if there is no such parent.
Expand All @@ -2887,11 +2890,15 @@ async def _get_parent_state(self, token: str) -> BaseState | None:
client_token, state_path = _split_substate_key(token)
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
cached_substates = None
if state is not None:
cached_substates = [state]
# Retrieve the parent state to populate event handlers onto this substate.
parent_state = await self.get_state(
token=_substate_key(client_token, parent_state_name),
top_level=False,
get_substates=False,
cached_substates=cached_substates,
)
return parent_state

Expand Down Expand Up @@ -2923,6 +2930,8 @@ async def _populate_substates(
tasks = {}
# Retrieve the necessary substates from redis.
for substate_cls in fetch_substates:
if substate_cls.get_name() in state.substates:
continue
substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task(
self.get_state(
Expand All @@ -2943,6 +2952,7 @@ async def get_state(
top_level: bool = True,
get_substates: bool = True,
parent_state: BaseState | None = None,
cached_substates: list[BaseState] | None = None,
) -> BaseState:
"""Get the state for a token.

Expand All @@ -2951,6 +2961,7 @@ async def get_state(
top_level: If true, return an instance of the top-level state (self.state).
get_substates: If true, also retrieve substates.
parent_state: If provided, use this parent_state instead of getting it from redis.
cached_substates: If provided, attach these substates to the state.

Returns:
The state for the token.
Expand All @@ -2968,45 +2979,38 @@ async def get_state(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
)

# The deserialized or newly created (sub)state instance.
state = None

# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)

if redis_state is not None:
# Deserialize the substate.
state = BaseState._deserialize(data=redis_state)

# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)

# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return state._get_root_state()
return state

# TODO: dedupe the following logic with the above block
# Key didn't exist so we have to create a new instance for this token.
with contextlib.suppress(StateSchemaMismatchError):
state = BaseState._deserialize(data=redis_state)
if state is None:
# Key didn't exist or schema mismatch so create a new instance for this token.
state = state_cls(
init_substates=False,
_reflex_internal_init=True,
)
# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token)
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
parent_state=parent_state,
init_substates=False,
_reflex_internal_init=True,
)
parent_state = await self._get_parent_state(token, state)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates for the newly created state.
# Avoid fetching substates multiple times.
if cached_substates:
for substate in cached_substates:
state.substates[substate.get_name()] = substate
if substate.parent_state is None:
substate.parent_state = state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)

# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
Expand Down
Loading