From e693e17103d444f3718cbb0ca6d7bf575d66e3ba Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Mon, 9 Sep 2024 03:35:47 +0200 Subject: [PATCH] fix var dependency dicts (#3842) --- reflex/state.py | 23 ++++++++++---------- tests/test_app.py | 11 +--------- tests/test_state.py | 52 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 98531cae9d..d74920622e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -582,6 +582,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.event_handlers[name] = handler setattr(cls, name, handler) + # Initialize per-class var dependency tracking. + cls._computed_var_dependencies = defaultdict(set) + cls._substate_var_dependencies = defaultdict(set) cls._init_var_dependency_dicts() @staticmethod @@ -651,10 +654,6 @@ def _init_var_dependency_dicts(cls): Additional updates tracking dicts for vars and substates that always need to be recomputed. """ - # Initialize per-class var dependency tracking. - cls._computed_var_dependencies = defaultdict(set) - cls._substate_var_dependencies = defaultdict(set) - inherited_vars = set(cls.inherited_vars).union( set(cls.inherited_backend_vars), ) @@ -1004,20 +1003,20 @@ def setup_dynamic_args(cls, args: dict[str, str]): Args: args: a dict of args """ + if not args: + return def argsingle_factory(param): - @ImmutableComputedVar def inner_func(self) -> str: return self.router.page.params.get(param, "") - return inner_func + return ImmutableComputedVar(fget=inner_func, cache=True) def arglist_factory(param): - @ImmutableComputedVar def inner_func(self) -> List: return self.router.page.params.get(param, []) - return inner_func + return ImmutableComputedVar(fget=inner_func, cache=True) for param, value in args.items(): if value == constants.RouteArgType.SINGLE: @@ -1031,8 +1030,8 @@ def inner_func(self) -> List: cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore setattr(cls, param, func) - # Reinitialize dependency tracking dicts. - cls._init_var_dependency_dicts() + # Reinitialize dependency tracking dicts. + cls._init_var_dependency_dicts() def __getattribute__(self, name: str) -> Any: """Get the state var. @@ -3598,5 +3597,7 @@ def reload_state_module( if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._init_var_dependency_dicts() + state._computed_var_dependencies = defaultdict(set) + state._substate_var_dependencies = defaultdict(set) + state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/tests/test_app.py b/tests/test_app.py index e3d98a86bc..167cbf0d48 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -906,7 +906,7 @@ def on_counter(self): """Increment the counter var.""" self.counter = self.counter + 1 - @immutable_computed_var + @immutable_computed_var(cache=True) def comp_dynamic(self) -> str: """A computed var that depends on the dynamic var. @@ -1049,9 +1049,6 @@ def _dynamic_state_event(name, val, **kwargs): assert on_load_update == StateUpdate( delta={ state.get_name(): { - # These computed vars _shouldn't_ be here, because they didn't change - arg_name: exp_val, - f"comp_{arg_name}": exp_val, "loaded": exp_index + 1, }, }, @@ -1073,9 +1070,6 @@ def _dynamic_state_event(name, val, **kwargs): assert on_set_is_hydrated_update == StateUpdate( delta={ state.get_name(): { - # These computed vars _shouldn't_ be here, because they didn't change - arg_name: exp_val, - f"comp_{arg_name}": exp_val, "is_hydrated": True, }, }, @@ -1097,9 +1091,6 @@ def _dynamic_state_event(name, val, **kwargs): assert update == StateUpdate( delta={ state.get_name(): { - # These computed vars _shouldn't_ be here, because they didn't change - f"comp_{arg_name}": exp_val, - arg_name: exp_val, "counter": exp_index + 1, } }, diff --git a/tests/test_state.py b/tests/test_state.py index 480c66d50b..29944840eb 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -649,7 +649,12 @@ def test_set_dirty_var(test_state): assert test_state.dirty_vars == set() -def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state): +def test_set_dirty_substate( + test_state: TestState, + child_state: ChildState, + child_state2: ChildState2, + grandchild_state: GrandchildState, +): """Test changing substate vars marks the value as dirty. Args: @@ -3077,6 +3082,51 @@ def bar(self) -> str: assert C1._potentially_dirty_substates() == set() +def test_router_var_dep() -> None: + """Test that router var dependencies are correctly tracked.""" + + class RouterVarParentState(State): + """A parent state for testing router var dependency.""" + + pass + + class RouterVarDepState(RouterVarParentState): + """A state with a router var dependency.""" + + @rx.var(cache=True) + def foo(self) -> str: + return self.router.page.params.get("foo", "") + + foo = RouterVarDepState.computed_vars["foo"] + State._init_var_dependency_dicts() + + assert foo._deps(objclass=RouterVarDepState) == {"router"} + assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} + assert RouterVarParentState._substate_var_dependencies == { + "router": {RouterVarDepState.get_name()} + } + assert RouterVarDepState._computed_var_dependencies == { + "router": {"foo"}, + } + + rx_state = State() + parent_state = RouterVarParentState() + state = RouterVarDepState() + + # link states + rx_state.substates = {RouterVarParentState.get_name(): parent_state} + parent_state.parent_state = rx_state + state.parent_state = parent_state + parent_state.substates = {RouterVarDepState.get_name(): state} + + assert state.dirty_vars == set() + + # Reassign router var + state.router = state.router + assert state.dirty_vars == {"foo", "router"} + assert parent_state.dirty_substates == {RouterVarDepState.get_name()} + + @pytest.mark.asyncio async def test_setvar(mock_app: rx.App, token: str): """Test that setvar works correctly.