Skip to content

Commit

Permalink
fix var dependency dicts (#3842)
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher authored and adhami3310 committed Sep 9, 2024
1 parent f54eefb commit e693e17
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 22 deletions.
23 changes: 12 additions & 11 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
cls.event_handlers[name] = handler
setattr(cls, name, handler)

# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
cls._init_var_dependency_dicts()

@staticmethod
Expand Down Expand Up @@ -651,10 +654,6 @@ def _init_var_dependency_dicts(cls):
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)

inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
Expand Down Expand Up @@ -1004,20 +1003,20 @@ def setup_dynamic_args(cls, args: dict[str, str]):
Args:
args: a dict of args
"""
if not args:
return

def argsingle_factory(param):
@ImmutableComputedVar
def inner_func(self) -> str:
return self.router.page.params.get(param, "")

return inner_func
return ImmutableComputedVar(fget=inner_func, cache=True)

def arglist_factory(param):
@ImmutableComputedVar
def inner_func(self) -> List:
return self.router.page.params.get(param, [])

return inner_func
return ImmutableComputedVar(fget=inner_func, cache=True)

for param, value in args.items():
if value == constants.RouteArgType.SINGLE:
Expand All @@ -1031,8 +1030,8 @@ def inner_func(self) -> List:
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
setattr(cls, param, func)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()
# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

def __getattribute__(self, name: str) -> Any:
"""Get the state var.
Expand Down Expand Up @@ -3598,5 +3597,7 @@ def reload_state_module(
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._init_var_dependency_dicts()
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()
11 changes: 1 addition & 10 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def on_counter(self):
"""Increment the counter var."""
self.counter = self.counter + 1

@immutable_computed_var
@immutable_computed_var(cache=True)
def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var.
Expand Down Expand Up @@ -1049,9 +1049,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert on_load_update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
arg_name: exp_val,
f"comp_{arg_name}": exp_val,
"loaded": exp_index + 1,
},
},
Expand All @@ -1073,9 +1070,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert on_set_is_hydrated_update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
arg_name: exp_val,
f"comp_{arg_name}": exp_val,
"is_hydrated": True,
},
},
Expand All @@ -1097,9 +1091,6 @@ def _dynamic_state_event(name, val, **kwargs):
assert update == StateUpdate(
delta={
state.get_name(): {
# These computed vars _shouldn't_ be here, because they didn't change
f"comp_{arg_name}": exp_val,
arg_name: exp_val,
"counter": exp_index + 1,
}
},
Expand Down
52 changes: 51 additions & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,12 @@ def test_set_dirty_var(test_state):
assert test_state.dirty_vars == set()


def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state):
def test_set_dirty_substate(
test_state: TestState,
child_state: ChildState,
child_state2: ChildState2,
grandchild_state: GrandchildState,
):
"""Test changing substate vars marks the value as dirty.
Args:
Expand Down Expand Up @@ -3077,6 +3082,51 @@ def bar(self) -> str:
assert C1._potentially_dirty_substates() == set()


def test_router_var_dep() -> None:
"""Test that router var dependencies are correctly tracked."""

class RouterVarParentState(State):
"""A parent state for testing router var dependency."""

pass

class RouterVarDepState(RouterVarParentState):
"""A state with a router var dependency."""

@rx.var(cache=True)
def foo(self) -> str:
return self.router.page.params.get("foo", "")

foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()

assert foo._deps(objclass=RouterVarDepState) == {"router"}
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
assert RouterVarParentState._substate_var_dependencies == {
"router": {RouterVarDepState.get_name()}
}
assert RouterVarDepState._computed_var_dependencies == {
"router": {"foo"},
}

rx_state = State()
parent_state = RouterVarParentState()
state = RouterVarDepState()

# link states
rx_state.substates = {RouterVarParentState.get_name(): parent_state}
parent_state.parent_state = rx_state
state.parent_state = parent_state
parent_state.substates = {RouterVarDepState.get_name(): state}

assert state.dirty_vars == set()

# Reassign router var
state.router = state.router
assert state.dirty_vars == {"foo", "router"}
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}


@pytest.mark.asyncio
async def test_setvar(mock_app: rx.App, token: str):
"""Test that setvar works correctly.
Expand Down

0 comments on commit e693e17

Please sign in to comment.