Skip to content

Commit

Permalink
typed mixins and ComponentState (#3196)
Browse files Browse the repository at this point in the history
* typed mixins

* implicit mixin=True kwarg for ComponentState subclasses

* fix: always init other subclasses

* adjust tests: all mixins support base vars now
  • Loading branch information
benedikt-bartscher authored May 15, 2024
1 parent 87a3dde commit d96baac
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 29 deletions.
1 change: 1 addition & 0 deletions integration/test_component_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test that per-component state scaffold works and operates independently."""

from typing import Generator

import pytest
Expand Down
66 changes: 44 additions & 22 deletions integration/test_state_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,15 @@ def StateInheritance():
"""Test that state inheritance works as expected."""
import reflex as rx

class ChildMixin:
# mixin basevars only work with pydantic/rx.Base models
# child_mixin: str = "child_mixin"
class ChildMixin(rx.State, mixin=True):
child_mixin: str = "child_mixin"

@rx.var
def computed_child_mixin(self) -> str:
return "computed_child_mixin"

class Mixin(ChildMixin):
# mixin basevars only work with pydantic/rx.Base models
# mixin: str = "mixin"
class Mixin(ChildMixin, mixin=True):
mixin: str = "mixin"

@rx.var
def computed_mixin(self) -> str:
Expand All @@ -64,7 +62,7 @@ def computed_mixin(self) -> str:
def on_click_mixin(self):
return rx.call_script("alert('clicked')")

class OtherMixin(rx.Base):
class OtherMixin(rx.State, mixin=True):
other_mixin: str = "other_mixin"
other_mixin_clicks: int = 0

Expand All @@ -78,7 +76,7 @@ def on_click_other_mixin(self):
f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}"
)

class Base1(rx.State, Mixin):
class Base1(Mixin, rx.State):
_base1: str = "_base1"
base1: str = "base1"

Expand Down Expand Up @@ -122,14 +120,15 @@ def computed_backend_vars_child3(self) -> str:

def index() -> rx.Component:
return rx.vstack(
rx.chakra.input(
rx.input(
id="token", value=Base1.router.session.client_token, is_read_only=True
),
# Base 1
# Base 1 (Mixin, ChildMixin)
rx.heading(Base1.computed_mixin, id="base1-computed_mixin"),
rx.heading(Base1.computed_basevar, id="base1-computed_basevar"),
rx.heading(Base1.computed_child_mixin, id="base1-child-mixin"),
rx.heading(Base1.computed_child_mixin, id="base1-computed-child-mixin"),
rx.heading(Base1.base1, id="base1-base1"),
rx.heading(Base1.child_mixin, id="base1-child-mixin"),
rx.button(
"Base1.on_click_mixin",
on_click=Base1.on_click_mixin, # type: ignore
Expand All @@ -138,31 +137,33 @@ def index() -> rx.Component:
rx.heading(
Base1.computed_backend_vars_base1, id="base1-computed_backend_vars"
),
# Base 2
# Base 2 (no mixins)
rx.heading(Base2.computed_basevar, id="base2-computed_basevar"),
rx.heading(Base2.base2, id="base2-base2"),
rx.heading(
Base2.computed_backend_vars_base2, id="base2-computed_backend_vars"
),
# Child 1
# Child 1 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child1.computed_basevar, id="child1-computed_basevar"),
rx.heading(Child1.computed_mixin, id="child1-computed_mixin"),
rx.heading(Child1.computed_other_mixin, id="child1-other-mixin"),
rx.heading(Child1.computed_child_mixin, id="child1-child-mixin"),
rx.heading(Child1.computed_child_mixin, id="child1-computed-child-mixin"),
rx.heading(Child1.base1, id="child1-base1"),
rx.heading(Child1.other_mixin, id="child1-other_mixin"),
rx.heading(Child1.child_mixin, id="child1-child-mixin"),
rx.button(
"Child1.on_click_other_mixin",
on_click=Child1.on_click_other_mixin, # type: ignore
id="child1-other-mixin-btn",
),
# Child 2
# Child 2 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child2.computed_basevar, id="child2-computed_basevar"),
rx.heading(Child2.computed_mixin, id="child2-computed_mixin"),
rx.heading(Child2.computed_other_mixin, id="child2-other-mixin"),
rx.heading(Child2.computed_child_mixin, id="child2-child-mixin"),
rx.heading(Child2.computed_child_mixin, id="child2-computed-child-mixin"),
rx.heading(Child2.base2, id="child2-base2"),
rx.heading(Child2.other_mixin, id="child2-other_mixin"),
rx.heading(Child2.child_mixin, id="child2-child-mixin"),
rx.button(
"Child2.on_click_mixin",
on_click=Child2.on_click_mixin, # type: ignore
Expand All @@ -173,15 +174,16 @@ def index() -> rx.Component:
on_click=Child2.on_click_other_mixin, # type: ignore
id="child2-other-mixin-btn",
),
# Child 3
# Child 3 (Mixin, ChildMixin, OtherMixin)
rx.heading(Child3.computed_basevar, id="child3-computed_basevar"),
rx.heading(Child3.computed_mixin, id="child3-computed_mixin"),
rx.heading(Child3.computed_other_mixin, id="child3-other-mixin"),
rx.heading(Child3.computed_childvar, id="child3-computed_childvar"),
rx.heading(Child3.computed_child_mixin, id="child3-child-mixin"),
rx.heading(Child3.computed_child_mixin, id="child3-computed-child-mixin"),
rx.heading(Child3.child3, id="child3-child3"),
rx.heading(Child3.base2, id="child3-base2"),
rx.heading(Child3.other_mixin, id="child3-other_mixin"),
rx.heading(Child3.child_mixin, id="child3-child-mixin"),
rx.button(
"Child3.on_click_mixin",
on_click=Child3.on_click_mixin, # type: ignore
Expand Down Expand Up @@ -282,7 +284,9 @@ def test_state_inheritance(
base1_computed_basevar = driver.find_element(By.ID, "base1-computed_basevar")
assert base1_computed_basevar.text == "computed_basevar1"

base1_computed_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
base1_computed_child_mixin = driver.find_element(
By.ID, "base1-computed-child-mixin"
)
assert base1_computed_child_mixin.text == "computed_child_mixin"

base1_base1 = driver.find_element(By.ID, "base1-base1")
Expand All @@ -293,6 +297,9 @@ def test_state_inheritance(
)
assert base1_computed_backend_vars.text == "_base1"

base1_child_mixin = driver.find_element(By.ID, "base1-child-mixin")
assert base1_child_mixin.text == "child_mixin"

# Base 2
base2_computed_basevar = driver.find_element(By.ID, "base2-computed_basevar")
assert base2_computed_basevar.text == "computed_basevar2"
Expand All @@ -315,7 +322,9 @@ def test_state_inheritance(
child1_computed_other_mixin = driver.find_element(By.ID, "child1-other-mixin")
assert child1_computed_other_mixin.text == "other_mixin"

child1_computed_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
child1_computed_child_mixin = driver.find_element(
By.ID, "child1-computed-child-mixin"
)
assert child1_computed_child_mixin.text == "computed_child_mixin"

child1_base1 = driver.find_element(By.ID, "child1-base1")
Expand All @@ -324,6 +333,9 @@ def test_state_inheritance(
child1_other_mixin = driver.find_element(By.ID, "child1-other_mixin")
assert child1_other_mixin.text == "other_mixin"

child1_child_mixin = driver.find_element(By.ID, "child1-child-mixin")
assert child1_child_mixin.text == "child_mixin"

# Child 2
child2_computed_basevar = driver.find_element(By.ID, "child2-computed_basevar")
assert child2_computed_basevar.text == "computed_basevar2"
Expand All @@ -334,7 +346,9 @@ def test_state_inheritance(
child2_computed_other_mixin = driver.find_element(By.ID, "child2-other-mixin")
assert child2_computed_other_mixin.text == "other_mixin"

child2_computed_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
child2_computed_child_mixin = driver.find_element(
By.ID, "child2-computed-child-mixin"
)
assert child2_computed_child_mixin.text == "computed_child_mixin"

child2_base2 = driver.find_element(By.ID, "child2-base2")
Expand All @@ -343,6 +357,9 @@ def test_state_inheritance(
child2_other_mixin = driver.find_element(By.ID, "child2-other_mixin")
assert child2_other_mixin.text == "other_mixin"

child2_child_mixin = driver.find_element(By.ID, "child2-child-mixin")
assert child2_child_mixin.text == "child_mixin"

# Child 3
child3_computed_basevar = driver.find_element(By.ID, "child3-computed_basevar")
assert child3_computed_basevar.text == "computed_basevar2"
Expand All @@ -356,7 +373,9 @@ def test_state_inheritance(
child3_computed_childvar = driver.find_element(By.ID, "child3-computed_childvar")
assert child3_computed_childvar.text == "computed_childvar"

child3_computed_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
child3_computed_child_mixin = driver.find_element(
By.ID, "child3-computed-child-mixin"
)
assert child3_computed_child_mixin.text == "computed_child_mixin"

child3_child3 = driver.find_element(By.ID, "child3-child3")
Expand All @@ -368,6 +387,9 @@ def test_state_inheritance(
child3_other_mixin = driver.find_element(By.ID, "child3-other_mixin")
assert child3_other_mixin.text == "other_mixin"

child3_child_mixin = driver.find_element(By.ID, "child3-child-mixin")
assert child3_child_mixin.text == "child_mixin"

child3_computed_backend_vars = driver.find_element(
By.ID, "child3-computed_backend_vars"
)
Expand Down
37 changes: 30 additions & 7 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Whether the state has ever been touched since instantiation.
_was_touched: bool = False

# Whether this state class is a mixin and should not be instantiated.
_mixin: ClassVar[bool] = False

# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]

Expand Down Expand Up @@ -428,17 +431,17 @@ def _get_computed_vars(cls) -> list[ComputedVar]:
"""
return [
v
for mixin in cls.__mro__
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
for mixin in cls._mixins() + [cls]
for v in mixin.__dict__.values()
if isinstance(v, ComputedVar)
]

@classmethod
def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls, mixin: bool = False, **kwargs):
"""Do some magic for the subclass initialization.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
Expand All @@ -447,6 +450,11 @@ def __init_subclass__(cls, **kwargs):
from reflex.utils.exceptions import StateValueError

super().__init_subclass__(**kwargs)

cls._mixin = mixin
if mixin:
return

# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
# Computed vars should not shadow builtin state props.
Expand Down Expand Up @@ -618,8 +626,11 @@ def _mixins(cls) -> List[Type]:
return [
mixin
for mixin in cls.__mro__
if not issubclass(mixin, (BaseState, ABC))
and mixin not in [pydantic.BaseModel, Base]
if (
mixin not in [pydantic.BaseModel, Base, cls]
and issubclass(mixin, BaseState)
and mixin._mixin is True
)
]

@classmethod
Expand Down Expand Up @@ -742,7 +753,7 @@ def get_parent_state(cls) -> Type[BaseState] | None:
parent_states = [
base
for base in cls.__bases__
if types._issubclass(base, BaseState) and base is not BaseState
if issubclass(base, BaseState) and base is not BaseState and not base._mixin
]
assert len(parent_states) < 2, "Only one parent state is allowed."
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
Expand Down Expand Up @@ -1833,7 +1844,7 @@ def on_load_internal(self) -> list[Event | EventSpec] | None:
]


class ComponentState(Base):
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
Expand Down Expand Up @@ -1875,6 +1886,18 @@ def get_component(cls, **props):
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0

@classmethod
def __init_subclass__(cls, mixin: bool = False, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
if ComponentState in cls.__bases__:
mixin = True
super().__init_subclass__(mixin=mixin, **kwargs)

@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Expand Down

0 comments on commit d96baac

Please sign in to comment.