diff --git a/reflex/state.py b/reflex/state.py index 424ead87c1..d080e57c90 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -143,6 +143,15 @@ def __init__(self, router_data: Optional[dict] = None): self.page = PageData(router_data) +RESERVED_BACKEND_VAR_NAMES = { + "_backend_vars", + "_computed_var_dependencies", + "_substate_var_dependencies", + "_always_dirty_computed_vars", + "_always_dirty_substates", +} + + class State(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -167,6 +176,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} + # Mapping of var name to set of computed variables that depend on it + _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + + # Mapping of var name to set of substates that depend on it + _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + + # Set of vars which always need to be recomputed + _always_dirty_computed_vars: ClassVar[Set[str]] = set() + + # Set of substates which always need to be recomputed + _always_dirty_substates: ClassVar[Set[str]] = set() + # The parent state. parent_state: Optional[State] = None @@ -182,12 +203,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The routing path that triggered the state router_data: Dict[str, Any] = {} - # Mapping of var name to set of computed variables that depend on it - computed_var_dependencies: Dict[str, Set[str]] = {} - - # Mapping of var name to set of substates that depend on it - substate_var_dependencies: Dict[str, Set[str]] = {} - # Per-instance copy of backend variable values _backend_vars: Dict[str, Any] = {} @@ -211,10 +226,6 @@ def __init__(self, *args, parent_state: State | None = None, **kwargs): kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) - # initialize per-instance var dependency tracking - self.computed_var_dependencies = defaultdict(set) - self.substate_var_dependencies = defaultdict(set) - # Setup the substates. for substate in self.get_substates(): substate_name = substate.get_name() @@ -227,25 +238,6 @@ def __init__(self, *args, parent_state: State | None = None, **kwargs): # Convert the event handlers to functions. self._init_event_handlers() - # Initialize computed vars dependencies. - inherited_vars = set(self.inherited_vars).union( - set(self.inherited_backend_vars), - ) - for cvar_name, cvar in self.computed_vars.items(): - # Add the dependencies. - for var in cvar._deps(objclass=type(self)): - self.computed_var_dependencies[var].add(cvar_name) - if var in inherited_vars: - # track that this substate depends on its parent for this var - state_name = self.get_name() - parent_state = self.parent_state - while parent_state is not None and var in parent_state.vars: - parent_state.substate_var_dependencies[var].add(state_name) - state_name, parent_state = ( - parent_state.get_name(), - parent_state.parent_state, - ) - # Create a fresh copy of the backend variables for this instance self._backend_vars = copy.deepcopy(self.backend_vars) @@ -347,6 +339,60 @@ def __init_subclass__(cls, **kwargs): cls.event_handlers[name] = handler setattr(cls, name, handler) + cls._init_var_dependency_dicts() + + @classmethod + def _init_var_dependency_dicts(cls): + """Initialize the var dependency tracking dicts. + + Allows the state to know which vars each ComputedVar depends on and + whether a ComputedVar depends on a var in its parent state. + + Additional updates tracking dicts for vars and substates that always + need to be recomputed. + """ + # Initialize per-class var dependency tracking. + cls._computed_var_dependencies = defaultdict(set) + cls._substate_var_dependencies = defaultdict(set) + + inherited_vars = set(cls.inherited_vars).union( + set(cls.inherited_backend_vars), + ) + for cvar_name, cvar in cls.computed_vars.items(): + # Add the dependencies. + for var in cvar._deps(objclass=cls): + cls._computed_var_dependencies[var].add(cvar_name) + if var in inherited_vars: + # track that this substate depends on its parent for this var + state_name = cls.get_name() + parent_state = cls.get_parent_state() + while parent_state is not None and var in parent_state.vars: + parent_state._substate_var_dependencies[var].add(state_name) + state_name, parent_state = ( + parent_state.get_name(), + parent_state.get_parent_state(), + ) + + # ComputedVar with cache=False always need to be recomputed + cls._always_dirty_computed_vars = set( + cvar_name + for cvar_name, cvar in cls.computed_vars.items() + if not cvar._cache + ) + + # Any substate containing a ComputedVar with cache=False always needs to be recomputed + cls._always_dirty_substates = set() + if cls._always_dirty_computed_vars: + # Tell parent classes that this substate has always dirty computed vars + state_name = cls.get_name() + parent_state = cls.get_parent_state() + while parent_state is not None: + parent_state._always_dirty_substates.add(state_name) + state_name, parent_state = ( + parent_state.get_name(), + parent_state.get_parent_state(), + ) + @classmethod def _check_overridden_methods(cls): """Check for shadow methods and raise error if any. @@ -377,16 +423,17 @@ def get_skip_vars(cls) -> set[str]: Returns: The vars to skip when serializing. """ - return set(cls.inherited_vars) | { - "parent_state", - "substates", - "dirty_vars", - "dirty_substates", - "router_data", - "computed_var_dependencies", - "substate_var_dependencies", - "_backend_vars", - } + return ( + set(cls.inherited_vars) + | { + "parent_state", + "substates", + "dirty_vars", + "dirty_substates", + "router_data", + } + | RESERVED_BACKEND_VAR_NAMES + ) @classmethod @functools.lru_cache() @@ -540,6 +587,9 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None): for substate_class in cls.__subclasses__(): substate_class.vars.setdefault(name, var) + # Reinitialize dependency tracking dicts. + cls._init_var_dependency_dicts() + @classmethod def _set_var(cls, prop: BaseVar): """Set the var as a class member. @@ -749,6 +799,9 @@ def inner_func(self) -> List: cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore setattr(cls, param, func) + # Reinitialize dependency tracking dicts. + cls._init_var_dependency_dicts() + def __getattribute__(self, name: str) -> Any: """Get the state var. @@ -804,7 +857,7 @@ def __setattr__(self, name: str, value: Any): setattr(self.parent_state, name, value) return - if types.is_backend_variable(name) and name != "_backend_vars": + if types.is_backend_variable(name) and name not in RESERVED_BACKEND_VAR_NAMES: self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty() @@ -814,7 +867,7 @@ def __setattr__(self, name: str, value: Any): super().__setattr__(name, value) # Add the var to the dirty list. - if name in self.vars or name in self.computed_var_dependencies: + if name in self.vars or name in self._computed_var_dependencies: self.dirty_vars.add(name) self._mark_dirty() @@ -1056,18 +1109,6 @@ async def _process_event( final=True, ) - def _always_dirty_computed_vars(self) -> set[str]: - """The set of ComputedVars that always need to be recalculated. - - Returns: - Set of all ComputedVar in this state where cache=False - """ - return set( - cvar_name - for cvar_name, cvar in self.computed_vars.items() - if not cvar._cache - ) - def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" dirty_vars = self.dirty_vars @@ -1092,7 +1133,7 @@ def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: return set( cvar for dirty_var in from_vars or self.dirty_vars - for cvar in self.computed_var_dependencies[dirty_var] + for cvar in self._computed_var_dependencies[dirty_var] ) def get_delta(self) -> Delta: @@ -1104,7 +1145,7 @@ def get_delta(self) -> Delta: delta = {} # Apply dirty variables down into substates - self.dirty_vars.update(self._always_dirty_computed_vars()) + self.dirty_vars.update(self._always_dirty_computed_vars) self._mark_dirty() # Return the dirty vars for this instance, any cached/dependent computed vars, @@ -1112,7 +1153,7 @@ def get_delta(self) -> Delta: delta_vars = ( self.dirty_vars.intersection(self.base_vars) .union(self._dirty_computed_vars()) - .union(self._always_dirty_computed_vars()) + .union(self._always_dirty_computed_vars) ) subdelta = { @@ -1125,7 +1166,7 @@ def get_delta(self) -> Delta: # Recursively find the substate deltas. substates = self.substates - for substate in self.dirty_substates: + for substate in self.dirty_substates.union(self._always_dirty_substates): delta.update(substates[substate].get_delta()) # Format the delta. @@ -1151,7 +1192,7 @@ def _mark_dirty(self): # Propagate dirty var / computed var status into substates substates = self.substates for var in self.dirty_vars: - for substate_name in self.substate_var_dependencies[var]: + for substate_name in self._substate_var_dependencies[var]: self.dirty_substates.add(substate_name) substate = substates[substate_name] substate.dirty_vars.add(var) @@ -1195,7 +1236,7 @@ def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]: if include_computed: # Apply dirty variables down into substates to allow never-cached ComputedVar to # trigger recalculation of dependent vars - self.dirty_vars.update(self._always_dirty_computed_vars()) + self.dirty_vars.update(self._always_dirty_computed_vars) self._mark_dirty() base_vars = { diff --git a/tests/test_app.py b/tests/test_app.py index e7ac03661a..b1baa8d2ff 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -257,7 +257,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { constants.ROUTER } - assert constants.ROUTER in app.state().computed_var_dependencies + assert constants.ROUTER in app.state()._computed_var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -917,7 +917,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == { constants.ROUTER } - assert constants.ROUTER in app.state().computed_var_dependencies + assert constants.ROUTER in app.state()._computed_var_dependencies sid = "mock_sid" client_ip = "127.0.0.1" diff --git a/tests/test_state.py b/tests/test_state.py index d85f0e59f9..591026d2cc 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1215,7 +1215,7 @@ def cached_x_side_effect(self) -> int: assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert "cached_x_side_effect" in s.computed_var_dependencies["x"] + assert "cached_x_side_effect" in s._computed_var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -1283,11 +1283,11 @@ def comp_z(self) -> List[bool]: return [z in self._z for z in range(5)] cs = ComputedState() - assert cs.computed_var_dependencies["v"] == {"comp_v"} - assert cs.computed_var_dependencies["w"] == {"comp_w"} - assert cs.computed_var_dependencies["x"] == {"comp_x"} - assert cs.computed_var_dependencies["y"] == {"comp_y"} - assert cs.computed_var_dependencies["_z"] == {"comp_z"} + assert cs._computed_var_dependencies["v"] == {"comp_v"} + assert cs._computed_var_dependencies["w"] == {"comp_w"} + assert cs._computed_var_dependencies["x"] == {"comp_x"} + assert cs._computed_var_dependencies["y"] == {"comp_y"} + assert cs._computed_var_dependencies["_z"] == {"comp_z"} def test_backend_method():