Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement on_mount and on_unmount for all components. #1636

Merged
merged 12 commits into from
Aug 30, 2023
93 changes: 93 additions & 0 deletions integration/test_event_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,35 @@ def on_load_yield_chain():
rx.input(value=State.token, readonly=True, id="token"),
)

def on_mount_return_chain():
return rx.fragment(
rx.text(
"return",
on_mount=State.on_load_return_chain,
on_unmount=lambda: State.event_arg("unmount"), # type: ignore
),
rx.input(value=State.token, readonly=True, id="token"),
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
)

def on_mount_yield_chain():
return rx.fragment(
rx.text(
"yield",
on_mount=[
State.on_load_yield_chain,
lambda: State.event_arg("mount"), # type: ignore
],
on_unmount=State.event_no_args,
),
rx.input(value=State.token, readonly=True, id="token"),
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
)

app.add_page(on_load_return_chain, on_load=State.on_load_return_chain) # type: ignore
app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain) # type: ignore
app.add_page(on_mount_return_chain)
app.add_page(on_mount_yield_chain)

app.compile()

Expand Down Expand Up @@ -330,3 +357,69 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
time.sleep(0.5)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order


@pytest.mark.parametrize(
("uri", "exp_event_order"),
[
(
"/on-mount-return-chain",
[
"on_load_return_chain",
"event_arg:unmount",
"on_load_return_chain",
"event_arg:1",
"event_arg:2",
"event_arg:3",
"event_arg:1",
"event_arg:2",
"event_arg:3",
"event_arg:unmount",
],
),
(
"/on-mount-yield-chain",
[
"on_load_yield_chain",
"event_arg:mount",
"event_no_args",
"on_load_yield_chain",
"event_arg:mount",
"event_arg:4",
"event_arg:5",
"event_arg:6",
"event_arg:4",
"event_arg:5",
"event_arg:6",
"event_no_args",
],
),
],
)
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
"""Load the URI, assert that the events are handled in the correct order.

These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
due to react StrictMode being used.

In prod mode, these events are only fired once.

Args:
event_chain: AppHarness for the event_chain app
driver: selenium WebDriver open to the app
uri: the page to load
exp_event_order: the expected events recorded in the State
"""
driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token")
assert token_input

token = event_chain.poll_for_value(token_input)

unmount_button = driver.find_element(By.ID, "unmount")
assert unmount_button
unmount_button.click()

time.sleep(1)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order
5 changes: 5 additions & 0 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ export const queueEvents = async (events, socket) => {
export const processEvent = async (
socket
) => {
// Only proceed if the socket is up, otherwise we throw the event into the void
if (!socket) {
return;
}

// Only proceed if we're not already processing an event.
if (event_queue.length === 0 || event_processing) {
return;
Expand Down
63 changes: 57 additions & 6 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ def get_triggers(self) -> Set[str]:
Returns:
The event triggers.
"""
return EVENT_TRIGGERS | set(self.get_controlled_triggers())
return (
EVENT_TRIGGERS
| set(self.get_controlled_triggers())
| set((constants.ON_MOUNT, constants.ON_UNMOUNT))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we just add these to the default EVENT_TRIGGERS ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i went back and forth on this, and ultimately did it this way because these aren't "real" event triggers that come from browser events, but are synthetic events that reflex handles.

)

def get_controlled_triggers(self) -> Dict[str, Var]:
"""Get the event triggers that pass the component's value to the handler.
Expand Down Expand Up @@ -525,16 +529,63 @@ def get_imports(self) -> imports.ImportDict:
self._get_imports(), *[child.get_imports() for child in self.children]
)

def _get_hooks(self) -> Optional[str]:
"""Get the React hooks for this component.
def _get_mount_lifecycle_hook(self) -> str | None:
"""Generate the component lifecycle hook.

Returns:
The hooks for just this component.
The useEffect hook for managing `on_mount` and `on_unmount` events.
"""
# pop on_mount and on_unmount from event_triggers since these are handled by
# hooks, not as actually props in the component
on_mount = self.event_triggers.pop(constants.ON_MOUNT, None)
on_unmount = self.event_triggers.pop(constants.ON_UNMOUNT, None)
if on_mount:
on_mount = format.format_event_chain(on_mount)
if on_unmount:
on_unmount = format.format_event_chain(on_unmount)
if on_mount or on_unmount:
return f"""
useEffect(() => {{
{on_mount or ""}
return () => {{
{on_unmount or ""}
}}
}}, []);"""

def _get_ref_hook(self) -> str | None:
"""Generate the ref hook for the component.

Returns:
The useRef hook for managing refs.
"""
ref = self.get_ref()
if ref is not None:
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
return None

def _get_hooks_internal(self) -> Set[str]:
"""Get the React hooks for this component managed by the framework.

Downstream components should NOT override this method to avoid breaking
framework functionality.

Returns:
Set of internally managed hooks.
"""
return set(
hook
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
if hook
)

def _get_hooks(self) -> Optional[str]:
"""Get the React hooks for this component.

Downstream components should override this method to add their own hooks.

Returns:
The hooks for just this component.
"""
return

def get_hooks(self) -> Set[str]:
"""Get the React hooks for this component and its children.
Expand All @@ -543,7 +594,7 @@ def get_hooks(self) -> Set[str]:
The code that should appear just before returning the rendered component.
"""
# Store the code in a set to avoid duplicates.
code = set()
code = self._get_hooks_internal()

# Add the hook code for this component.
hooks = self._get_hooks()
Expand Down
8 changes: 4 additions & 4 deletions reflex/components/forms/pininput.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def get_ref(self):
"""
return None

def _get_hooks(self) -> Optional[str]:
"""Override the base get_hooks to handle array refs.
def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs.

Returns:
The overrided hooks.
Expand All @@ -86,7 +86,7 @@ def _get_hooks(self) -> Optional[str]:
ref = format.format_array_ref(self.id, None)
if ref:
return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));"
return super()._get_hooks()
return super()._get_ref_hook()

@classmethod
def create(cls, *children, **props) -> Component:
Expand Down Expand Up @@ -130,7 +130,7 @@ class PinInputField(ChakraComponent):
# Default to None because it is assigned by PinInput when created.
index: Optional[Var[int]] = None

def _get_hooks(self) -> Optional[str]:
def _get_ref_hook(self) -> Optional[str]:
return None

def get_ref(self):
Expand Down
8 changes: 4 additions & 4 deletions reflex/components/forms/rangeslider.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def get_ref(self):
"""
return None

def _get_hooks(self) -> Optional[str]:
"""Override the base get_hooks to handle array refs.
def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs.

Returns:
The overrided hooks.
Expand All @@ -74,7 +74,7 @@ def _get_hooks(self) -> Optional[str]:
ref = format.format_array_ref(self.id, None)
if ref:
return f"const {ref} = Array.from({{length:2}}, () => useRef(null));"
return super()._get_hooks()
return super()._get_ref_hook()

@classmethod
def create(cls, *children, **props) -> Component:
Expand Down Expand Up @@ -130,7 +130,7 @@ class RangeSliderThumb(ChakraComponent):
# The position of the thumb.
index: Var[int]

def _get_hooks(self) -> Optional[str]:
def _get_ref_hook(self) -> Optional[str]:
# hook is None because RangeSlider is handling it.
return None

Expand Down
4 changes: 4 additions & 0 deletions reflex/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,5 +359,9 @@ class RouteRegex(SimpleNamespace):
# Alembic migrations
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")

# Names of event handlers on all components mapped to useEffect
ON_MOUNT = "on_mount"
ON_UNMOUNT = "on_unmount"

# If this env var is set to "yes", App.compile will be a no-op
SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
61 changes: 60 additions & 1 deletion reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from reflex import constants
from reflex.utils import types
from reflex.vars import Var

if TYPE_CHECKING:
from reflex.components.component import ComponentStyle
from reflex.event import EventHandler, EventSpec
from reflex.event import EventChain, EventHandler, EventSpec

WRAP_MAP = {
"{": "}",
Expand Down Expand Up @@ -182,6 +183,24 @@ def format_string(string: str) -> str:
return string


def format_var(var: Var) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be the same as the Var.str method? Should we move this code there? This was the intention of that method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, it's a little different. i cribbed this from Tag.format_prop, so i assumed it must be doing something special. i'll see if i can combine them without breaking tests, so we can just use str(var) instead.

"""Format the given Var as a javascript value.

Args:
var: The Var to format.

Returns:
The formatted Var.
"""
if not var.is_local or var.is_string:
return str(var)
if types._issubclass(var.type_, str):
return format_string(var.full_name)
if is_wrapped(var.full_name, "{"):
return var.full_name
return json_dumps(var.full_name)


def format_route(route: str) -> str:
"""Format the given route.

Expand Down Expand Up @@ -311,6 +330,46 @@ def format_event(event_spec: EventSpec) -> str:
return f"E({', '.join(event_args)})"


def format_event_chain(
event_chain: EventChain | Var[EventChain],
event_arg: Var | None = None,
) -> str:
"""Format an event chain as a javascript invocation.

Args:
event_chain: The event chain to queue on the frontend.
event_arg: The browser-native event (only used to preventDefault).

Returns:
Compiled javascript code to queue the given event chain on the frontend.

Raises:
ValueError: When the given event chain is not a valid event chain.
"""
if isinstance(event_chain, Var):
from reflex.event import EventChain

if event_chain.type_ is not EventChain:
raise ValueError(f"Invalid event chain: {event_chain}")
return "".join(
[
"(() => {",
format_var(event_chain),
f"; preventDefault({format_var(event_arg)})" if event_arg else "",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Event function in state.js should also handle the preventDefault logic I thought?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we're formatting an arbitrary var as an event chain, there's no guarantee that Event gets called at all; it could be anything.

"})()",
]
)

chain = ",".join([format_event(event) for event in event_chain.events])
return "".join(
[
f"Event([{chain}]",
f", {format_var(event_arg)}" if event_arg else "",
")",
]
)


def format_query_params(router_data: Dict[str, Any]) -> Dict[str, str]:
"""Convert back query params name to python-friendly case.

Expand Down
6 changes: 4 additions & 2 deletions tests/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import reflex as rx
from reflex.components.component import Component, CustomComponent, custom_component
from reflex.components.layout.box import Box
from reflex.constants import ON_MOUNT, ON_UNMOUNT
from reflex.event import EVENT_ARG, EVENT_TRIGGERS, EventHandler
from reflex.state import State
from reflex.style import Style
Expand Down Expand Up @@ -377,8 +378,9 @@ def test_get_triggers(component1, component2):
component1: A test component.
component2: A test component.
"""
assert component1().get_triggers() == EVENT_TRIGGERS
assert component2().get_triggers() == {"on_open", "on_close"} | EVENT_TRIGGERS
default_triggers = {ON_MOUNT, ON_UNMOUNT} | EVENT_TRIGGERS
assert component1().get_triggers() == default_triggers
assert component2().get_triggers() == {"on_open", "on_close"} | default_triggers


def test_create_custom_component(my_component):
Expand Down