Skip to content

Commit

Permalink
state: Initialize _always_dirty_computed_vars and _always_dirty_subst…
Browse files Browse the repository at this point in the history
…ates

Instead of calculating these every time there is a delta, determine them
at subclass definition time (or when adding a var dynamically).

This should improve performance for every delta calculation.

It also fixes #2066 by tracking substates which have ComputedVar with
_cache=False

Fix REF-1035
  • Loading branch information
masenf committed Nov 22, 2023
1 parent bd74a23 commit 76a51c1
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __init__(self, router_data: Optional[dict] = None):
"_backend_vars",
"_computed_var_dependencies",
"_substate_var_dependencies",
"_always_dirty_computed_vars",
"_always_dirty_substates",
}


Expand Down Expand Up @@ -180,6 +182,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Mapping of var name to set of substates that depend on it
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}

# Set of vars which always need to be recomputed
_always_dirty_computed_vars: ClassVar[Set[str]] = set()

# Set of substates which always need to be recomputed
_always_dirty_substates: ClassVar[Set[str]] = set()

# The parent state.
parent_state: Optional[State] = None

Expand Down Expand Up @@ -335,11 +343,18 @@ def __init_subclass__(cls, **kwargs):

@classmethod
def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts.
Allows the state to know which vars each ComputedVar depends on and
whether a ComputedVar depends on a var in its parent state.
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)

# Initialize computed vars dependencies.
inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
Expand All @@ -358,6 +373,26 @@ def _init_var_dependency_dicts(cls):
parent_state.get_parent_state(),
)

# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = set(
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if not cvar._cache
)

# Any substate containing a ComputedVar with cache=False always needs to be recomputed
cls._always_dirty_substates = set()
if cls._always_dirty_computed_vars:
# Tell parent classes that this substate has always dirty computed vars
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None:
parent_state._always_dirty_substates.add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)

@classmethod
def _check_overridden_methods(cls):
"""Check for shadow methods and raise error if any.
Expand Down Expand Up @@ -1071,18 +1106,6 @@ async def _process_event(
final=True,
)

def _always_dirty_computed_vars(self) -> set[str]:
"""The set of ComputedVars that always need to be recalculated.
Returns:
Set of all ComputedVar in this state where cache=False
"""
return set(
cvar_name
for cvar_name, cvar in self.computed_vars.items()
if not cvar._cache
)

def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
dirty_vars = self.dirty_vars
Expand Down Expand Up @@ -1119,15 +1142,15 @@ def get_delta(self) -> Delta:
delta = {}

# Apply dirty variables down into substates
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()

# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)

subdelta = {
Expand All @@ -1140,7 +1163,7 @@ def get_delta(self) -> Delta:

# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates:
for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta())

# Format the delta.
Expand Down Expand Up @@ -1210,7 +1233,7 @@ def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
if include_computed:
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()

base_vars = {
Expand Down

0 comments on commit 76a51c1

Please sign in to comment.