Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-1035] Track ComputedVar dependency per class #2067

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 100 additions & 59 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ def __init__(self, router_data: Optional[dict] = None):
self.page = PageData(router_data)


RESERVED_BACKEND_VAR_NAMES = {
"_backend_vars",
"_computed_var_dependencies",
"_substate_var_dependencies",
"_always_dirty_computed_vars",
"_always_dirty_substates",
}


class State(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""

Expand All @@ -167,6 +176,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}

# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}

# Mapping of var name to set of substates that depend on it
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}

# Set of vars which always need to be recomputed
_always_dirty_computed_vars: ClassVar[Set[str]] = set()

# Set of substates which always need to be recomputed
_always_dirty_substates: ClassVar[Set[str]] = set()

# The parent state.
parent_state: Optional[State] = None

Expand All @@ -182,12 +203,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state
router_data: Dict[str, Any] = {}

# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {}

# Mapping of var name to set of substates that depend on it
substate_var_dependencies: Dict[str, Set[str]] = {}

# Per-instance copy of backend variable values
_backend_vars: Dict[str, Any] = {}

Expand All @@ -211,10 +226,6 @@ def __init__(self, *args, parent_state: State | None = None, **kwargs):
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)

# initialize per-instance var dependency tracking
self.computed_var_dependencies = defaultdict(set)
self.substate_var_dependencies = defaultdict(set)

# Setup the substates.
for substate in self.get_substates():
substate_name = substate.get_name()
Expand All @@ -227,25 +238,6 @@ def __init__(self, *args, parent_state: State | None = None, **kwargs):
# Convert the event handlers to functions.
self._init_event_handlers()

# Initialize computed vars dependencies.
inherited_vars = set(self.inherited_vars).union(
set(self.inherited_backend_vars),
)
for cvar_name, cvar in self.computed_vars.items():
# Add the dependencies.
for var in cvar._deps(objclass=type(self)):
self.computed_var_dependencies[var].add(cvar_name)
if var in inherited_vars:
# track that this substate depends on its parent for this var
state_name = self.get_name()
parent_state = self.parent_state
while parent_state is not None and var in parent_state.vars:
parent_state.substate_var_dependencies[var].add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.parent_state,
)

# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)

Expand Down Expand Up @@ -347,6 +339,60 @@ def __init_subclass__(cls, **kwargs):
cls.event_handlers[name] = handler
setattr(cls, name, handler)

cls._init_var_dependency_dicts()

@classmethod
def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts.

Allows the state to know which vars each ComputedVar depends on and
whether a ComputedVar depends on a var in its parent state.

Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)

inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
for cvar_name, cvar in cls.computed_vars.items():
# Add the dependencies.
for var in cvar._deps(objclass=cls):
cls._computed_var_dependencies[var].add(cvar_name)
if var in inherited_vars:
# track that this substate depends on its parent for this var
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None and var in parent_state.vars:
parent_state._substate_var_dependencies[var].add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)

# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = set(
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if not cvar._cache
)

# Any substate containing a ComputedVar with cache=False always needs to be recomputed
cls._always_dirty_substates = set()
if cls._always_dirty_computed_vars:
# Tell parent classes that this substate has always dirty computed vars
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None:
parent_state._always_dirty_substates.add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)

@classmethod
def _check_overridden_methods(cls):
"""Check for shadow methods and raise error if any.
Expand Down Expand Up @@ -377,16 +423,17 @@ def get_skip_vars(cls) -> set[str]:
Returns:
The vars to skip when serializing.
"""
return set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
"computed_var_dependencies",
"substate_var_dependencies",
"_backend_vars",
}
return (
set(cls.inherited_vars)
| {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
}
| RESERVED_BACKEND_VAR_NAMES
)

@classmethod
@functools.lru_cache()
Expand Down Expand Up @@ -540,6 +587,9 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None):
for substate_class in cls.__subclasses__():
substate_class.vars.setdefault(name, var)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

@classmethod
def _set_var(cls, prop: BaseVar):
"""Set the var as a class member.
Expand Down Expand Up @@ -749,6 +799,9 @@ def inner_func(self) -> List:
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
setattr(cls, param, func)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

def __getattribute__(self, name: str) -> Any:
"""Get the state var.

Expand Down Expand Up @@ -804,7 +857,7 @@ def __setattr__(self, name: str, value: Any):
setattr(self.parent_state, name, value)
return

if types.is_backend_variable(name) and name != "_backend_vars":
if types.is_backend_variable(name) and name not in RESERVED_BACKEND_VAR_NAMES:
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self._mark_dirty()
Expand All @@ -814,7 +867,7 @@ def __setattr__(self, name: str, value: Any):
super().__setattr__(name, value)

# Add the var to the dirty list.
if name in self.vars or name in self.computed_var_dependencies:
if name in self.vars or name in self._computed_var_dependencies:
self.dirty_vars.add(name)
self._mark_dirty()

Expand Down Expand Up @@ -1056,18 +1109,6 @@ async def _process_event(
final=True,
)

def _always_dirty_computed_vars(self) -> set[str]:
"""The set of ComputedVars that always need to be recalculated.

Returns:
Set of all ComputedVar in this state where cache=False
"""
return set(
cvar_name
for cvar_name, cvar in self.computed_vars.items()
if not cvar._cache
)

def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
dirty_vars = self.dirty_vars
Expand All @@ -1092,7 +1133,7 @@ def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
return set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_var_dependencies[dirty_var]
for cvar in self._computed_var_dependencies[dirty_var]
)

def get_delta(self) -> Delta:
Expand All @@ -1104,15 +1145,15 @@ def get_delta(self) -> Delta:
delta = {}

# Apply dirty variables down into substates
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()

# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)

subdelta = {
Expand All @@ -1125,7 +1166,7 @@ def get_delta(self) -> Delta:

# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates:
for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta())

# Format the delta.
Expand All @@ -1151,7 +1192,7 @@ def _mark_dirty(self):
# Propagate dirty var / computed var status into substates
substates = self.substates
for var in self.dirty_vars:
for substate_name in self.substate_var_dependencies[var]:
for substate_name in self._substate_var_dependencies[var]:
self.dirty_substates.add(substate_name)
substate = substates[substate_name]
substate.dirty_vars.add(var)
Expand Down Expand Up @@ -1195,7 +1236,7 @@ def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
if include_computed:
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()

base_vars = {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
constants.ROUTER
}
assert constants.ROUTER in app.state().computed_var_dependencies
assert constants.ROUTER in app.state()._computed_var_dependencies


def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
Expand Down Expand Up @@ -917,7 +917,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
constants.ROUTER
}
assert constants.ROUTER in app.state().computed_var_dependencies
assert constants.ROUTER in app.state()._computed_var_dependencies

sid = "mock_sid"
client_ip = "127.0.0.1"
Expand Down
12 changes: 6 additions & 6 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ def cached_x_side_effect(self) -> int:
assert isinstance(HandlerState.handler, EventHandler)

s = HandlerState()
assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
Expand Down Expand Up @@ -1283,11 +1283,11 @@ def comp_z(self) -> List[bool]:
return [z in self._z for z in range(5)]

cs = ComputedState()
assert cs.computed_var_dependencies["v"] == {"comp_v"}
assert cs.computed_var_dependencies["w"] == {"comp_w"}
assert cs.computed_var_dependencies["x"] == {"comp_x"}
assert cs.computed_var_dependencies["y"] == {"comp_y"}
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
assert cs._computed_var_dependencies["v"] == {"comp_v"}
assert cs._computed_var_dependencies["w"] == {"comp_w"}
assert cs._computed_var_dependencies["x"] == {"comp_x"}
assert cs._computed_var_dependencies["y"] == {"comp_y"}
assert cs._computed_var_dependencies["_z"] == {"comp_z"}


def test_backend_method():
Expand Down
Loading