Skip to content

Commit

Permalink
defer dependency dicts to compile, except for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed Aug 27, 2024
1 parent 3ce38cd commit f3a98f3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 16 deletions.
8 changes: 8 additions & 0 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,13 @@ def _apply_decorated_pages(self):
for render, kwargs in DECORATED_PAGES[get_config().app_name]:
self.add_page(render, **kwargs)

def _init_var_dependency_dicts(self):
"""Recursively initialize the var dependency dictionaries for all state subclasses of the app."""
if not self.state:
return

self.state._reset_var_dependency_dicts()

def _validate_var_dependencies(
self, state: Optional[Type[BaseState]] = None
) -> None:
Expand Down Expand Up @@ -853,6 +860,7 @@ def get_compilation_time() -> str:
if not self._should_compile():
return

self._init_var_dependency_dicts()
self._validate_var_dependencies()
self._setup_overlay_component()
self._setup_error_boundary()
Expand Down
31 changes: 17 additions & 14 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,8 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
cls.event_handlers[name] = handler
setattr(cls, name, handler)

# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
cls._init_var_dependency_dicts()
if is_testing_env():
BaseState._reset_var_dependency_dicts()

@staticmethod
def _copy_fn(fn: Callable) -> Callable:
Expand Down Expand Up @@ -651,6 +649,21 @@ def _mixins(cls) -> List[Type]:
)
]

@classmethod
def _reset_var_dependency_dicts(cls, state: Type[BaseState] | None = None):
if not state:
state = cls
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
if is_testing_env():
# clear all subclasses, because substates are cleared in BaseState.__init_subclass__
substates = state.__subclasses__()
else:
substates = state.get_substates()
for substate in substates:
substate._reset_var_dependency_dicts(state=substate)
cls._init_var_dependency_dicts()

@classmethod
def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts.
Expand Down Expand Up @@ -918,9 +931,6 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None):
for substate_class in cls.class_subclasses:
substate_class.vars.setdefault(name, var)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

@classmethod
def _set_var(cls, prop: ImmutableVar):
"""Set the var as a class member.
Expand Down Expand Up @@ -1037,9 +1047,6 @@ def inner_func(self) -> List:
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
setattr(cls, param, func)

# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()

def __getattribute__(self, name: str) -> Any:
"""Get the state var.
Expand Down Expand Up @@ -1709,9 +1716,6 @@ def _mark_dirty_substates(self):
substates = self.substates
for var in self.dirty_vars:
for substate_name in self._substate_var_dependencies[var]:
if substate_name not in substates:
# TODO: why is this happening?
continue
self.dirty_substates.add(substate_name)
substate = substates[substate_name]
substate.dirty_vars.add(var)
Expand Down Expand Up @@ -3365,5 +3369,4 @@ def reload_state_module(
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()
2 changes: 2 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
route.lstrip("/").replace("/", "\\")
assert app.pages == {}
app.add_page(index_page, route=route)
app._init_var_dependency_dicts()
assert app.pages.keys() == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
Expand Down Expand Up @@ -946,6 +947,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
assert app.state is not None
assert arg_name not in app.state.vars
app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
app._init_var_dependency_dicts()
assert arg_name in app.state.vars
assert arg_name in app.state.computed_vars
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
Expand Down
3 changes: 1 addition & 2 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def test_get_substate(test_state, child_state, child_state2, grandchild_state):
)


def test_set_dirty_var(test_state):
def test_set_dirty_var(test_state: TestState):
"""Test changing state vars marks the value as dirty.
Args:
Expand Down Expand Up @@ -3085,7 +3085,6 @@ def foo(self) -> str:
return self.router.page.params.get("foo", "")

foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()

assert foo._deps(objclass=RouterVarDepState) == {"router"}
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
Expand Down

0 comments on commit f3a98f3

Please sign in to comment.