diff --git a/reflex/state.py b/reflex/state.py index 2da1104503..9333893a79 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1515,6 +1515,39 @@ def _get_loaded_substates( loaded_substates[substate.get_full_name()] = substate substate._get_loaded_substates(loaded_substates) + def _serialize_touched_states(self) -> dict[str, bytes]: + """Serialize all touched states in the state tree. + + Returns: + The serialized states. + """ + root_state = self._get_root_state() + d = {} + if root_state._get_was_touched(): + serialized = root_state._serialize() + if serialized: + d[root_state.get_full_name()] = serialized + root_state._serialize_touched_substates(d) + return d + + def _serialize_touched_substates( + self, + touched_substates: dict[str, bytes], + ) -> None: + """Serialize all touched substates of this state. + + Args: + touched_substates: A dictionary of touched substates which will be updated with the substates of this state. + """ + for substate in self.substates.values(): + substate._serialize_touched_substates(touched_substates) + if not substate._get_was_touched(): + continue + serialized = substate._serialize() + if not serialized: + continue + touched_substates[substate.get_full_name()] = serialized + def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -3464,15 +3497,7 @@ async def set_state( f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." ) - redis_hashset = {} - - for state_name, substate in state._get_loaded_states().items(): - if not substate._get_was_touched(): - continue - pickle_state = substate._serialize() - if not pickle_state: - continue - redis_hashset[state_name] = pickle_state + redis_hashset = state._serialize_touched_states() if not redis_hashset: return