Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix var dependency dicts #3842

Merged
merged 8 commits into from
Sep 9, 2024
Merged
23 changes: 12 additions & 11 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
cls.event_handlers[name] = handler
setattr(cls, name, handler)

# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
cls._init_var_dependency_dicts()

@staticmethod
Expand Down Expand Up @@ -660,10 +663,6 @@ def _init_var_dependency_dicts(cls):
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)

inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
Expand Down Expand Up @@ -1013,20 +1012,20 @@ def setup_dynamic_args(cls, args: dict[str, str]):
Args:
args: a dict of args
"""
if not args:
return

def argsingle_factory(param):
@ComputedVar
def inner_func(self) -> str:
return self.router.page.params.get(param, "")

return inner_func
return ComputedVar(fget=inner_func, cache=True)

def arglist_factory(param):
@ComputedVar
def inner_func(self) -> List:
return self.router.page.params.get(param, [])

return inner_func
return ComputedVar(fget=inner_func, cache=True)

for param, value in args.items():
if value == constants.RouteArgType.SINGLE:
Expand All @@ -1040,8 +1039,8 @@ def inner_func(self) -> List:
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
setattr(cls, param, func)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()
# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

def __getattribute__(self, name: str) -> Any:
"""Get the state var.
Expand Down Expand Up @@ -3607,5 +3606,7 @@ def reload_state_module(
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._init_var_dependency_dicts()
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()
11 changes: 1 addition & 10 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def on_counter(self):
"""Increment the counter var."""
self.counter = self.counter + 1

@computed_var
@computed_var(cache=True)
def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var.

Expand Down Expand Up @@ -1049,9 +1049,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert on_load_update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
arg_name: exp_val,
f"comp_{arg_name}": exp_val,
"loaded": exp_index + 1,
},
},
Expand All @@ -1073,9 +1070,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert on_set_is_hydrated_update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
arg_name: exp_val,
f"comp_{arg_name}": exp_val,
"is_hydrated": True,
},
},
Expand All @@ -1097,9 +1091,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
f"comp_{arg_name}": exp_val,
arg_name: exp_val,
"counter": exp_index + 1,
}
},
Expand Down
52 changes: 51 additions & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,12 @@ def test_set_dirty_var(test_state):
assert test_state.dirty_vars == set()


def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state):
def test_set_dirty_substate(
test_state: TestState,
child_state: ChildState,
child_state2: ChildState2,
grandchild_state: GrandchildState,
):
"""Test changing substate vars marks the value as dirty.

Args:
Expand Down Expand Up @@ -3077,6 +3082,51 @@ def bar(self) -> str:
assert C1._potentially_dirty_substates() == set()


def test_router_var_dep() -> None:
"""Test that router var dependencies are correctly tracked."""

class RouterVarParentState(State):
"""A parent state for testing router var dependency."""

pass

class RouterVarDepState(RouterVarParentState):
"""A state with a router var dependency."""

@rx.var(cache=True)
def foo(self) -> str:
return self.router.page.params.get("foo", "")

foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()

assert foo._deps(objclass=RouterVarDepState) == {"router"}
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
assert RouterVarParentState._substate_var_dependencies == {
"router": {RouterVarDepState.get_name()}
}
assert RouterVarDepState._computed_var_dependencies == {
"router": {"foo"},
}

rx_state = State()
parent_state = RouterVarParentState()
state = RouterVarDepState()

# link states
rx_state.substates = {RouterVarParentState.get_name(): parent_state}
parent_state.parent_state = rx_state
state.parent_state = parent_state
parent_state.substates = {RouterVarDepState.get_name(): state}

assert state.dirty_vars == set()

# Reassign router var
state.router = state.router
assert state.dirty_vars == {"foo", "router"}
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}


@pytest.mark.asyncio
async def test_setvar(mock_app: rx.App, token: str):
"""Test that setvar works correctly.
Expand Down
Loading