diff --git a/reflex/state.py b/reflex/state.py index 5798564fa4..76586b47d5 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2874,11 +2874,14 @@ class StateManagerRedis(StateManager): # Only warn about each state class size once. _warned_about_state_size: ClassVar[Set[str]] = set() - async def _get_parent_state(self, token: str) -> BaseState | None: + async def _get_parent_state( + self, token: str, state: BaseState | None = None + ) -> BaseState | None: """Get the parent state for the state requested in the token. Args: token: The token to get the state for (_substate_key). + state: The state instance to get parent state for. Returns: The parent state for the state requested by the token or None if there is no such parent. @@ -2887,11 +2890,15 @@ async def _get_parent_state(self, token: str) -> BaseState | None: client_token, state_path = _split_substate_key(token) parent_state_name = state_path.rpartition(".")[0] if parent_state_name: + cached_substates = None + if state is not None: + cached_substates = [state] # Retrieve the parent state to populate event handlers onto this substate. parent_state = await self.get_state( token=_substate_key(client_token, parent_state_name), top_level=False, get_substates=False, + cached_substates=cached_substates, ) return parent_state @@ -2923,6 +2930,8 @@ async def _populate_substates( tasks = {} # Retrieve the necessary substates from redis. for substate_cls in fetch_substates: + if substate_cls.get_name() in state.substates: + continue substate_name = substate_cls.get_name() tasks[substate_name] = asyncio.create_task( self.get_state( @@ -2943,6 +2952,7 @@ async def get_state( top_level: bool = True, get_substates: bool = True, parent_state: BaseState | None = None, + cached_substates: list[BaseState] | None = None, ) -> BaseState: """Get the state for a token. @@ -2951,6 +2961,7 @@ async def get_state( top_level: If true, return an instance of the top-level state (self.state). get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. + cached_substates: If provided, attach these substates to the state. Returns: The state for the token. @@ -2968,45 +2979,38 @@ async def get_state( "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" ) + # The deserialized or newly created (sub)state instance. + state = None + # Fetch the serialized substate from redis. redis_state = await self.redis.get(token) if redis_state is not None: # Deserialize the substate. - state = BaseState._deserialize(data=redis_state) - - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token) - # Set up Bidirectional linkage between this state and its parent. - if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) - - # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. - if top_level: - return state._get_root_state() - return state - - # TODO: dedupe the following logic with the above block - # Key didn't exist so we have to create a new instance for this token. + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=redis_state) + if state is None: + # Key didn't exist or schema mismatch so create a new instance for this token. + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + # Populate parent state if missing and requested. if parent_state is None: - parent_state = await self._get_parent_state(token) - # Instantiate the new state class (but don't persist it yet). - state = state_cls( - parent_state=parent_state, - init_substates=False, - _reflex_internal_init=True, - ) + parent_state = await self._get_parent_state(token, state) # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: parent_state.substates[state.get_name()] = state state.parent_state = parent_state - # Populate substates for the newly created state. + # Avoid fetching substates multiple times. + if cached_substates: + for substate in cached_substates: + state.substates[substate.get_name()] = substate + if substate.parent_state is None: + substate.parent_state = state + # Populate substates if requested. await self._populate_substates(token, state, all_substates=get_substates) + # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. if top_level: