diff --git a/reflex/app.py b/reflex/app.py index 5be0ef04013..82d731fb2aa 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -799,6 +799,13 @@ def _apply_decorated_pages(self): for render, kwargs in DECORATED_PAGES[get_config().app_name]: self.add_page(render, **kwargs) + def _init_var_dependency_dicts(self): + """Recursively initialize the var dependency dictionaries for all state subclasses of the app.""" + if not self.state: + return + + self.state._reset_var_dependency_dicts() + def _validate_var_dependencies( self, state: Optional[Type[BaseState]] = None ) -> None: @@ -853,6 +860,7 @@ def get_compilation_time() -> str: if not self._should_compile(): return + self._init_var_dependency_dicts() self._validate_var_dependencies() self._setup_overlay_component() self._setup_error_boundary() diff --git a/reflex/state.py b/reflex/state.py index 8fcaa7e193f..e2fc9d81c33 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -589,10 +589,8 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.event_handlers[name] = handler setattr(cls, name, handler) - # Initialize per-class var dependency tracking. - cls._computed_var_dependencies = defaultdict(set) - cls._substate_var_dependencies = defaultdict(set) - cls._init_var_dependency_dicts() + if is_testing_env(): + BaseState._reset_var_dependency_dicts() @staticmethod def _copy_fn(fn: Callable) -> Callable: @@ -651,6 +649,21 @@ def _mixins(cls) -> List[Type]: ) ] + @classmethod + def _reset_var_dependency_dicts(cls, state: Type[BaseState] | None = None): + if not state: + state = cls + cls._computed_var_dependencies = defaultdict(set) + cls._substate_var_dependencies = defaultdict(set) + if is_testing_env(): + # clear all subclasses, because substates are cleared in BaseState.__init_subclass__ + substates = state.__subclasses__() + else: + substates = state.get_substates() + for substate in substates: + substate._reset_var_dependency_dicts(state=substate) + cls._init_var_dependency_dicts() + @classmethod def _init_var_dependency_dicts(cls): """Initialize the var dependency tracking dicts. @@ -918,9 +931,6 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None): for substate_class in cls.class_subclasses: substate_class.vars.setdefault(name, var) - # Reinitialize dependency tracking dicts. - cls._init_var_dependency_dicts() - @classmethod def _set_var(cls, prop: ImmutableVar): """Set the var as a class member. @@ -1037,9 +1047,6 @@ def inner_func(self) -> List: cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore setattr(cls, param, func) - # Reinitialize dependency tracking dicts. - cls._init_var_dependency_dicts() - def __getattribute__(self, name: str) -> Any: """Get the state var. @@ -1709,9 +1716,6 @@ def _mark_dirty_substates(self): substates = self.substates for var in self.dirty_vars: for substate_name in self._substate_var_dependencies[var]: - if substate_name not in substates: - # TODO: why is this happening? - continue self.dirty_substates.add(substate_name) substate = substates[substate_name] substate.dirty_vars.add(var) @@ -3365,5 +3369,4 @@ def reload_state_module( state._always_dirty_substates.discard(subclass.get_name()) state._computed_var_dependencies = defaultdict(set) state._substate_var_dependencies = defaultdict(set) - state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/tests/test_app.py b/tests/test_app.py index cb778f0db79..c428f1832e2 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -271,6 +271,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): route.lstrip("/").replace("/", "\\") assert app.pages == {} app.add_page(index_page, route=route) + app._init_var_dependency_dicts() assert app.pages.keys() == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { @@ -946,6 +947,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert app.state is not None assert arg_name not in app.state.vars app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore + app._init_var_dependency_dicts() assert arg_name in app.state.vars assert arg_name in app.state.computed_vars assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == { diff --git a/tests/test_state.py b/tests/test_state.py index 83e3f897938..2965bdfbe0a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -626,7 +626,7 @@ def test_get_substate(test_state, child_state, child_state2, grandchild_state): ) -def test_set_dirty_var(test_state): +def test_set_dirty_var(test_state: TestState): """Test changing state vars marks the value as dirty. Args: @@ -3085,7 +3085,6 @@ def foo(self) -> str: return self.router.page.params.get("foo", "") foo = RouterVarDepState.computed_vars["foo"] - State._init_var_dependency_dicts() assert foo._deps(objclass=RouterVarDepState) == {"router"} assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}