diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 5bc6b8b8b0..8386261e95 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -126,8 +126,13 @@ export const applyDelta = (state, delta) => { export const applyEvent = async (event, socket) => { // Handle special events if (event.name == "_redirect") { - if (event.payload.external) window.open(event.payload.path, "_blank"); - else Router.push(event.payload.path); + if (event.payload.external) { + window.open(event.payload.path, "_blank"); + } else if (event.payload.replace) { + Router.replace(event.payload.path); + } else { + Router.push(event.payload.path); + } return false; } diff --git a/reflex/app.py b/reflex/app.py index 4d99d6949e..4a7c60e2e4 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -41,7 +41,6 @@ from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils from reflex.compiler.compiler import ExecutorSafeFunctions -from reflex.components import connection_modal, connection_pulser from reflex.components.base.app_wrap import AppWrap from reflex.components.base.fragment import Fragment from reflex.components.component import ( @@ -49,6 +48,7 @@ ComponentStyle, evaluate_style_namespaces, ) +from reflex.components.core import connection_pulser, connection_toaster from reflex.components.core.client_side_routing import ( Default404Page, wait_for_client_redirect, @@ -91,7 +91,7 @@ def default_overlay_component() -> Component: Returns: The default overlay_component, which is a connection_modal. """ - return Fragment.create(connection_pulser(), connection_modal()) + return Fragment.create(connection_pulser(), connection_toaster()) class OverlayFragment(Fragment): diff --git a/reflex/components/component.py b/reflex/components/component.py index 3608f8fa5f..b039ebe0d3 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -241,7 +241,7 @@ def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: """ return {} - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var]: """Add hooks inside the component function. Hooks are pieces of literal Javascript code that is inserted inside the @@ -1265,11 +1265,20 @@ def _get_hooks_imports(self) -> imports.ImportDict: }, ) + other_imports = [] user_hooks = self._get_hooks() - if user_hooks is not None and isinstance(user_hooks, Var): - _imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore + if ( + user_hooks is not None + and isinstance(user_hooks, Var) + and user_hooks._var_data is not None + and user_hooks._var_data.imports + ): + other_imports.append(user_hooks._var_data.imports) + other_imports.extend( + hook_imports for hook_imports in self._get_added_hooks().values() + ) - return _imports + return imports.merge_imports(_imports, *other_imports) def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -1416,6 +1425,36 @@ def _get_hooks_internal(self) -> dict[str, None]: **self._get_special_hooks(), } + def _get_added_hooks(self) -> dict[str, imports.ImportDict]: + """Get the hooks added via `add_hooks` method. + + Returns: + The deduplicated hooks and imports added by the component and parent components. + """ + code = {} + + def extract_var_hooks(hook: Var): + _imports = {} + if hook._var_data is not None: + for sub_hook in hook._var_data.hooks: + code[sub_hook] = {} + if hook._var_data.imports: + _imports = hook._var_data.imports + if str(hook) in code: + code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) + else: + code[str(hook)] = _imports + + # Add the hook code from add_hooks for each parent class (this is reversed to preserve + # the order of the hooks in the final output) + for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): + for hook in clz.add_hooks(self): + if isinstance(hook, Var): + extract_var_hooks(hook) + else: + code[hook] = {} + return code + def _get_hooks(self) -> str | None: """Get the React hooks for this component. @@ -1454,11 +1493,7 @@ def _get_all_hooks(self) -> dict[str, None]: if hooks is not None: code[hooks] = None - # Add the hook code from add_hooks for each parent class (this is reversed to preserve - # the order of the hooks in the final output) - for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): - for hook in clz.add_hooks(self): - code[hook] = None + code.update(self._get_added_hooks()) # Add the hook code for the children. for child in self.children: @@ -2092,8 +2127,8 @@ def _get_memoized_event_triggers( var_deps.extend(cls._get_hook_deps(hook)) memo_var_data = VarData.merge( *[var._var_data for var in event_args], - VarData( # type: ignore - imports={"react": {ImportVar(tag="useCallback")}}, + VarData( + imports={"react": [ImportVar(tag="useCallback")]}, ), ) diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index 80c73add87..877d27739d 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -1,7 +1,12 @@ """Core Reflex components.""" from . import layout as layout -from .banner import ConnectionBanner, ConnectionModal, ConnectionPulser +from .banner import ( + ConnectionBanner, + ConnectionModal, + ConnectionPulser, + ConnectionToaster, +) from .colors import color from .cond import Cond, color_mode_cond, cond from .debounce import DebounceInput @@ -26,6 +31,7 @@ connection_banner = ConnectionBanner.create connection_modal = ConnectionModal.create +connection_toaster = ConnectionToaster.create connection_pulser = ConnectionPulser.create debounce_input = DebounceInput.create foreach = Foreach.create diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index 0c781fba8b..c6250743cf 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -16,8 +16,11 @@ ) from reflex.components.radix.themes.layout import Flex from reflex.components.radix.themes.typography.text import Text +from reflex.components.sonner.toast import Toaster, ToastProps from reflex.constants import Dirs, Hooks, Imports +from reflex.constants.compiler import CompileVars from reflex.utils import imports +from reflex.utils.serializers import serialize from reflex.vars import Var, VarData connect_error_var_data: VarData = VarData( # type: ignore @@ -25,27 +28,38 @@ hooks={Hooks.EVENTS: None}, ) +connect_errors: Var = Var.create_safe( + value=CompileVars.CONNECT_ERROR, + _var_is_local=True, + _var_is_string=False, + _var_data=connect_error_var_data, +) + connection_error: Var = Var.create_safe( value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''", _var_is_local=False, _var_is_string=False, -)._replace(merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +) connection_errors_count: Var = Var.create_safe( value="connectErrors.length", _var_is_string=False, _var_is_local=False, -)._replace(merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +) has_connection_errors: Var = Var.create_safe( value="connectErrors.length > 0", _var_is_string=False, -)._replace(_var_type=bool, merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +).to(bool) has_too_many_connection_errors: Var = Var.create_safe( value="connectErrors.length >= 2", _var_is_string=False, -)._replace(_var_type=bool, merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +).to(bool) class WebsocketTargetURL(Bare): @@ -81,6 +95,64 @@ def default_connection_error() -> list[str | Var | Component]: ] +class ConnectionToaster(Toaster): + """A connection toaster component.""" + + def add_hooks(self) -> list[str]: + """Add the hooks for the connection toaster. + + Returns: + The hooks for the connection toaster. + """ + toast_id = "websocket-error" + target_url = WebsocketTargetURL.create() + props = ToastProps( # type: ignore + description=Var.create( + f"`Check if server is reachable at ${target_url}`", + _var_is_string=False, + _var_is_local=False, + ), + close_button=True, + duration=120000, + id=toast_id, + ) + hook = Var.create( + f""" +const toast_props = {serialize(props)}; +const [userDismissed, setUserDismissed] = useState(false); +useEffect(() => {{ + if ({has_too_many_connection_errors}) {{ + if (!userDismissed) {{ + toast.error( + `Cannot connect to server: {connection_error}.`, + {{...toast_props, onDismiss: () => setUserDismissed(true)}}, + ) + }} + }} else {{ + toast.dismiss("{toast_id}"); + setUserDismissed(false); // after reconnection reset dismissed state + }} +}}, [{connect_errors}]);""" + ) + + hook._var_data = VarData.merge( # type: ignore + connect_errors._var_data, + VarData( + imports={ + "react": [ + imports.ImportVar(tag="useEffect"), + imports.ImportVar(tag="useState"), + ], + **target_url._get_imports(), + } + ), + ) + return [ + Hooks.EVENTS, + hook, # type: ignore + ] + + class ConnectionBanner(Component): """A connection banner component.""" @@ -158,8 +230,8 @@ def create(cls, **props) -> Component: size=props.pop("size", 32), z_index=props.pop("z_index", 9999), position=props.pop("position", "fixed"), - bottom=props.pop("botton", "30px"), - right=props.pop("right", "30px"), + bottom=props.pop("botton", "33px"), + right=props.pop("right", "33px"), animation=Var.create(f"${{pulse}} 1s infinite", _var_is_string=True), **props, ) @@ -201,6 +273,7 @@ def create(cls, **props) -> Component: has_connection_errors, WifiOffPulse.create(**props), ), + title=f"Connection Error: {connection_error}", position="fixed", width="100vw", height="0", diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index 43fc53e291..64f9761f9a 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -20,11 +20,15 @@ from reflex.components.radix.themes.components.dialog import ( ) from reflex.components.radix.themes.layout import Flex from reflex.components.radix.themes.typography.text import Text +from reflex.components.sonner.toast import Toaster, ToastProps from reflex.constants import Dirs, Hooks, Imports +from reflex.constants.compiler import CompileVars from reflex.utils import imports +from reflex.utils.serializers import serialize from reflex.vars import Var, VarData connect_error_var_data: VarData +connect_errors: Var connection_error: Var connection_errors_count: Var has_connection_errors: Var @@ -99,6 +103,132 @@ class WebsocketTargetURL(Bare): def default_connection_error() -> list[str | Var | Component]: ... +class ConnectionToaster(Toaster): + def add_hooks(self) -> list[str]: ... + @overload + @classmethod + def create( # type: ignore + cls, + *children, + theme: Optional[Union[Var[str], str]] = None, + rich_colors: Optional[Union[Var[bool], bool]] = None, + expand: Optional[Union[Var[bool], bool]] = None, + visible_toasts: Optional[Union[Var[int], int]] = None, + position: Optional[ + Union[ + Var[ + Literal[ + "top-left", + "top-center", + "top-right", + "bottom-left", + "bottom-center", + "bottom-right", + ] + ], + Literal[ + "top-left", + "top-center", + "top-right", + "bottom-left", + "bottom-center", + "bottom-right", + ], + ] + ] = None, + close_button: Optional[Union[Var[bool], bool]] = None, + offset: Optional[Union[Var[str], str]] = None, + dir: Optional[Union[Var[str], str]] = None, + hotkey: Optional[Union[Var[str], str]] = None, + invert: Optional[Union[Var[bool], bool]] = None, + toast_options: Optional[Union[Var[ToastProps], ToastProps]] = None, + gap: Optional[Union[Var[int], int]] = None, + loading_icon: Optional[Union[Var[Icon], Icon]] = None, + pause_when_page_is_hidden: Optional[Union[Var[bool], bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + **props + ) -> "ConnectionToaster": + """Create the component. + + Args: + *children: The children of the component. + theme: the theme of the toast + rich_colors: whether to show rich colors + expand: whether to expand the toast + visible_toasts: the number of toasts that are currently visible + position: the position of the toast + close_button: whether to show the close button + offset: offset of the toast + dir: directionality of the toast (default: ltr) + hotkey: Keyboard shortcut that will move focus to the toaster area. + invert: Dark toasts in light mode and vice versa. + toast_options: These will act as default options for all toasts. See toast() for all available options. + gap: Gap between toasts when expanded + loading_icon: Changes the default loading icon + pause_when_page_is_hidden: Pauses toast timers when the page is hidden, e.g., when the tab is backgrounded, the browser is minimized, or the OS is locked. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + class ConnectionBanner(Component): @overload @classmethod diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 0e3e436725..9ace92b98b 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -13,7 +13,7 @@ from reflex.vars import BaseVar, Var, VarData _IS_TRUE_IMPORT = { - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, + f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")], } diff --git a/reflex/components/core/debounce.py b/reflex/components/core/debounce.py index 5fabd44861..88d1e1f943 100644 --- a/reflex/components/core/debounce.py +++ b/reflex/components/core/debounce.py @@ -109,13 +109,11 @@ def create(cls, *children: Component, **props: Any) -> Component: "{%s}" % (child.alias or child.tag), _var_is_local=False, _var_is_string=False, - )._replace( - _var_type=Type[Component], - merge_var_data=VarData( # type: ignore + _var_data=VarData( imports=child._get_imports(), hooks=child._get_hooks_internal(), ), - ), + ).to(Type[Component]), ) component = super().create(**props) diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 88f2886a8a..9a6765491d 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -8,6 +8,7 @@ from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode +from reflex.state import ComponentState from reflex.utils import console from reflex.vars import Var @@ -50,6 +51,7 @@ def create( Raises: ForeachVarError: If the iterable is of type Any. + TypeError: If the render function is a ComponentState. """ if props: console.deprecate( @@ -65,6 +67,15 @@ def create( "(If you are trying to foreach over a state var, add a type annotation to the var). " "See https://reflex.dev/docs/library/layout/foreach/" ) + + if ( + hasattr(render_fn, "__qualname__") + and render_fn.__qualname__ == ComponentState.create.__qualname__ + ): + raise TypeError( + "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." + ) + component = cls( iterable=iterable, render_fn=render_fn, diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index b3ac37c15d..65c441924c 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -24,12 +24,12 @@ DEFAULT_UPLOAD_ID: str = "default" -upload_files_context_var_data: VarData = VarData( # type: ignore +upload_files_context_var_data: VarData = VarData( imports={ - "react": {imports.ImportVar(tag="useContext")}, - f"/{Dirs.CONTEXTS_PATH}": { + "react": [imports.ImportVar(tag="useContext")], + f"/{Dirs.CONTEXTS_PATH}": [ imports.ImportVar(tag="UploadFilesContext"), - }, + ], }, hooks={ "const [filesById, setFilesById] = useContext(UploadFilesContext);": None, @@ -118,14 +118,13 @@ def get_upload_dir() -> Path: uploaded_files_url_prefix: Var = Var.create_safe( - "${getBackendURL(env.UPLOAD)}" -)._replace( - merge_var_data=VarData( # type: ignore + "${getBackendURL(env.UPLOAD)}", + _var_data=VarData( imports={ - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, - "/env.json": {imports.ImportVar(tag="env", is_default=True)}, + f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], + "/env.json": [imports.ImportVar(tag="env", is_default=True)], } - ) + ), ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index a98bd47c7f..37051b2797 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -216,13 +216,17 @@ def _get_form_refs(self) -> Dict[str, Any]: if ref.startswith("refs_"): ref_var = Var.create_safe(ref[:-3]).as_ref() form_refs[ref[5:-3]] = Var.create_safe( - f"getRefValues({str(ref_var)})", _var_is_local=False - )._replace(merge_var_data=ref_var._var_data) + f"getRefValues({str(ref_var)})", + _var_is_local=False, + _var_data=ref_var._var_data, + ) else: ref_var = Var.create_safe(ref).as_ref() form_refs[ref[4:]] = Var.create_safe( - f"getRefValue({str(ref_var)})", _var_is_local=False - )._replace(merge_var_data=ref_var._var_data) + f"getRefValue({str(ref_var)})", + _var_is_local=False, + _var_data=ref_var._var_data, + ) return form_refs def _get_vars(self, include_children: bool = True) -> Iterator[Var]: @@ -619,14 +623,16 @@ def _render(self) -> Tag: on_key_down=Var.create_safe( f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})", _var_is_local=False, - )._replace(merge_var_data=self.enter_key_submit._var_data), + _var_data=self.enter_key_submit._var_data, + ) ) if self.auto_height is not None: tag.add_props( on_input=Var.create_safe( f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})", _var_is_local=False, - )._replace(merge_var_data=self.auto_height._var_data), + _var_data=self.auto_height._var_data, + ) ) return tag diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index 6c05dfd811..fd0a220212 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -114,12 +114,14 @@ def _render(self) -> Tag: _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], _var_full_name_needs_state_prefix=True, - )._replace(merge_var_data=self.data._var_data) + _var_data=self.data._var_data, + ) self.data = BaseVar( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], _var_full_name_needs_state_prefix=True, - )._replace(merge_var_data=self.data._var_data) + _var_data=self.data._var_data, + ) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/components/radix/themes/components/tabs.py b/reflex/components/radix/themes/components/tabs.py index af1b6b5218..130cfd166a 100644 --- a/reflex/components/radix/themes/components/tabs.py +++ b/reflex/components/radix/themes/components/tabs.py @@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent): _valid_parents: List[str] = ["TabsList"] @classmethod - def create(self, *children, **props) -> Component: + def create(cls, *children, **props) -> Component: """Create a TabsTrigger component. Args: diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index 23a855aee1..648b0db9c4 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union from reflex.base import Base from reflex.components.component import Component, ComponentNamespace @@ -74,7 +74,7 @@ class ToastProps(PropsBase): """Props for the toast component.""" # Toast's description, renders underneath the title. - description: Optional[str] + description: Optional[Union[str, Var]] # Whether to show the close button. close_button: Optional[bool] @@ -162,7 +162,7 @@ def dict(self, *args, **kwargs) -> dict: class Toaster(Component): """A Toaster Component for displaying toast notifications.""" - library = "sonner@1.4.41" + library: str = "sonner@1.4.41" tag = "Toaster" @@ -209,12 +209,15 @@ class Toaster(Component): pause_when_page_is_hidden: Var[bool] def _get_hooks(self) -> Var[str]: - hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True) - hook._var_data = VarData( # type: ignore - imports={ - "/utils/state": [ImportVar(tag="refs")], - self.library: [ImportVar(tag="toast", install=False)], - } + hook = Var.create_safe( + f"{toast_ref} = toast", + _var_is_local=True, + _var_data=VarData( + imports={ + "/utils/state": [ImportVar(tag="refs")], + self.library: [ImportVar(tag="toast", install=False)], + } + ), ) return hook diff --git a/reflex/components/sonner/toast.pyi b/reflex/components/sonner/toast.pyi index 5bd6cdeb41..6bc5ab2b5e 100644 --- a/reflex/components/sonner/toast.pyi +++ b/reflex/components/sonner/toast.pyi @@ -7,7 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union from reflex.base import Base from reflex.components.component import Component, ComponentNamespace from reflex.components.lucide.icon import Icon @@ -37,7 +37,7 @@ class ToastAction(Base): def serialize_action(action: ToastAction) -> dict: ... class ToastProps(PropsBase): - description: Optional[str] + description: Optional[Union[str, Var]] close_button: Optional[bool] invert: Optional[bool] important: Optional[bool] diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b99e31e8c7..96e8b03ba7 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -103,9 +103,9 @@ class Imports(SimpleNamespace): """Common sets of import vars.""" EVENTS = { - "react": {ImportVar(tag="useContext")}, - f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, - f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, + "react": [ImportVar(tag="useContext")], + f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")], + f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)], } diff --git a/reflex/event.py b/reflex/event.py index 3f1487c199..96a59fdc10 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -467,18 +467,27 @@ def fn(): ) -def redirect(path: str | Var[str], external: Optional[bool] = False) -> EventSpec: +def redirect( + path: str | Var[str], + external: Optional[bool] = False, + replace: Optional[bool] = False, +) -> EventSpec: """Redirect to a new path. Args: path: The path to redirect to. external: Whether to open in new tab or not. + replace: If True, the current page will not create a new history entry. Returns: An event to redirect to the path. """ return server_side( - "_redirect", get_fn_signature(redirect), path=path, external=external + "_redirect", + get_fn_signature(redirect), + path=path, + external=external, + replace=replace, ) diff --git a/reflex/experimental/__init__.py b/reflex/experimental/__init__.py index 29bda85453..6972fdfe0d 100644 --- a/reflex/experimental/__init__.py +++ b/reflex/experimental/__init__.py @@ -8,6 +8,7 @@ from ..utils.console import warn from . import hooks as hooks +from .client_state import ClientStateVar as ClientStateVar from .layout import layout as layout from .misc import run_in_thread as run_in_thread @@ -16,6 +17,7 @@ ) _x = SimpleNamespace( + client_state=ClientStateVar.create, hooks=hooks, layout=layout, progress=progress, diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py new file mode 100644 index 0000000000..93405b29fb --- /dev/null +++ b/reflex/experimental/client_state.py @@ -0,0 +1,198 @@ +"""Handle client side state with `useState`.""" + +import dataclasses +import sys +from typing import Any, Callable, Optional, Type + +from reflex import constants +from reflex.event import EventChain, EventHandler, EventSpec, call_script +from reflex.utils.imports import ImportVar +from reflex.vars import Var, VarData + + +def _client_state_ref(var_name: str) -> str: + """Get the ref path for a ClientStateVar. + + Args: + var_name: The name of the variable. + + Returns: + An accessor for ClientStateVar ref as a string. + """ + return f"refs['_client_state_{var_name}']" + + +@dataclasses.dataclass( + eq=False, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ClientStateVar(Var): + """A Var that exists on the client via useState.""" + + # The name of the var. + _var_name: str = dataclasses.field() + + # Track the names of the getters and setters + _setter_name: str = dataclasses.field() + _getter_name: str = dataclasses.field() + + # The type of the var. + _var_type: Type = dataclasses.field(default=Any) + + # Whether this is a local javascript variable. + _var_is_local: bool = dataclasses.field(default=False) + + # Whether the var is a string literal. + _var_is_string: bool = dataclasses.field(default=False) + + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False) + + # Extra metadata associated with the Var + _var_data: Optional[VarData] = dataclasses.field(default=None) + + def __hash__(self) -> int: + """Define a hash function for a var. + + Returns: + The hash of the var. + """ + return hash( + (self._var_name, str(self._var_type), self._getter_name, self._setter_name) + ) + + @classmethod + def create(cls, var_name, default=None) -> "ClientStateVar": + """Create a local_state Var that can be accessed and updated on the client. + + The `ClientStateVar` should be included in the highest parent component + that contains the components which will access and manipulate the client + state. It has no visual rendering, including it ensures that the + `useState` hook is called in the correct scope. + + To render the var in a component, use the `value` property. + + To update the var in a component, use the `set` property. + + To access the var in an event handler, use the `retrieve` method with + `callback` set to the event handler which should receive the value. + + To update the var in an event handler, use the `push` method with the + value to update. + + Args: + var_name: The name of the variable. + default: The default value of the variable. + + Returns: + ClientStateVar + """ + if default is None: + default_var = Var.create_safe("", _var_is_local=False, _var_is_string=False) + elif not isinstance(default, Var): + default_var = Var.create_safe(default) + else: + default_var = default + setter_name = f"set{var_name.capitalize()}" + return cls( + _var_name="", + _setter_name=setter_name, + _getter_name=var_name, + _var_is_local=False, + _var_is_string=False, + _var_type=default_var._var_type, + _var_data=VarData.merge( + default_var._var_data, + VarData( # type: ignore + hooks={ + f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None, + f"{_client_state_ref(var_name)} = {var_name}": None, + f"{_client_state_ref(setter_name)} = {setter_name}": None, + }, + imports={ + "react": [ImportVar(tag="useState", install=False)], + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], + }, + ), + ), + ) + + @property + def value(self) -> Var: + """Get a placeholder for the Var. + + This property can only be rendered on the frontend. + + To access the value in a backend event handler, see `retrieve`. + + Returns: + an accessor for the client state variable. + """ + return ( + Var.create_safe( + _client_state_ref(self._getter_name), + _var_is_local=False, + _var_is_string=False, + ) + .to(self._var_type) + ._replace( + merge_var_data=VarData( # type: ignore + imports={ + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], + } + ) + ) + ) + + @property + def set(self) -> Var: + """Set the value of the client state variable. + + This property can only be attached to a frontend event trigger. + + To set a value from a backend event handler, see `push`. + + Returns: + A special EventChain Var which will set the value when triggered. + """ + return ( + Var.create_safe( + _client_state_ref(self._setter_name), + _var_is_local=False, + _var_is_string=False, + ) + .to(EventChain) + ._replace( + merge_var_data=VarData( # type: ignore + imports={ + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], + } + ) + ) + ) + + def retrieve(self, callback: EventHandler | Callable | None = None) -> EventSpec: + """Pass the value of the client state variable to a backend EventHandler. + + The event handler must `yield` or `return` the EventSpec to trigger the event. + + Args: + callback: The callback to pass the value to. + + Returns: + An EventSpec which will retrieve the value when triggered. + """ + return call_script(_client_state_ref(self._getter_name), callback=callback) + + def push(self, value: Any) -> EventSpec: + """Push a value to the client state variable from the backend. + + The event handler must `yield` or `return` the EventSpec to trigger the event. + + Args: + value: The value to update. + + Returns: + An EventSpec which will push the value when triggered. + """ + return call_script(f"{_client_state_ref(self._setter_name)}({value})") diff --git a/reflex/state.py b/reflex/state.py index 287f70073a..86a222b666 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -550,7 +550,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): for name, value in mixin.__dict__.items(): if isinstance(value, ComputedVar): fget = cls._copy_fn(value.fget) - newcv = ComputedVar(fget=fget, _var_name=value._var_name) + newcv = value._replace(fget=fget) + # cleanup refs to mixin cls in var_data + newcv._var_data = None newcv._var_set_state(cls) setattr(cls, name, newcv) cls.computed_vars[newcv._var_name] = newcv diff --git a/reflex/style.py b/reflex/style.py index e48aa3dd82..21a601dd0a 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -16,10 +16,10 @@ DARK_COLOR_MODE: str = "dark" # Reference the global ColorModeContext -color_mode_var_data = VarData( # type: ignore +color_mode_var_data = VarData( imports={ - f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, - "react": {ImportVar(tag="useContext")}, + f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")], + "react": [ImportVar(tag="useContext")], }, hooks={ f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None, diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index cd4739c445..1852f0434f 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -413,7 +413,7 @@ def initialize_gitignore( # Write files to the .gitignore file. with open(gitignore_file, "w", newline="\n") as f: console.debug(f"Creating {gitignore_file}") - f.write(f"{(path_ops.join(sorted(files_to_ignore))).lstrip()}") + f.write(f"{(path_ops.join(sorted(files_to_ignore))).lstrip()}\n") def initialize_requirements_txt(): diff --git a/reflex/utils/types.py b/reflex/utils/types.py index f75e20dcc4..6dd120e3dd 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -44,6 +44,22 @@ from reflex.base import Base from reflex.utils import console, serializers +if sys.version_info >= (3, 12): + from typing import override +else: + + def override(func: Callable) -> Callable: + """Fallback for @override decorator. + + Args: + func: The function to decorate. + + Returns: + The unmodified function. + """ + return func + + # Potential GenericAlias types for isinstance checks. GenericAliasTypes = [_GenericAlias] diff --git a/reflex/vars.py b/reflex/vars.py index 6ac78706a1..be6aa7eb89 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -39,10 +39,12 @@ # This module used to export ImportVar itself, so we still import it for export here from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.types import override if TYPE_CHECKING: from reflex.state import BaseState + # Set of unique variable names. USED_VARIABLES = set() @@ -341,7 +343,11 @@ class Var: @classmethod def create( - cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False + cls, + value: Any, + _var_is_local: bool = True, + _var_is_string: bool = False, + _var_data: Optional[VarData] = None, ) -> Var | None: """Create a var from a value. @@ -349,6 +355,7 @@ def create( value: The value to create the var from. _var_is_local: Whether the var is local. _var_is_string: Whether the var is a string literal. + _var_data: Additional hooks and imports associated with the Var. Returns: The var. @@ -365,9 +372,8 @@ def create( return value # Try to pull the imports and hooks from contained values. - _var_data = None if not isinstance(value, str): - _var_data = VarData.merge(*_extract_var_data(value)) + _var_data = VarData.merge(*_extract_var_data(value), _var_data) # Try to serialize the value. type_ = type(value) @@ -388,7 +394,11 @@ def create( @classmethod def create_safe( - cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False + cls, + value: Any, + _var_is_local: bool = True, + _var_is_string: bool = False, + _var_data: Optional[VarData] = None, ) -> Var: """Create a var from a value, asserting that it is not None. @@ -396,6 +406,7 @@ def create_safe( value: The value to create the var from. _var_is_local: Whether the var is local. _var_is_string: Whether the var is a string literal. + _var_data: Additional hooks and imports associated with the Var. Returns: The var. @@ -404,6 +415,7 @@ def create_safe( value, _var_is_local=_var_is_local, _var_is_string=_var_is_string, + _var_data=_var_data, ) assert var is not None return var @@ -822,19 +834,19 @@ def get_operand_full_name(operand): if invoke_fn: # invoke the function on left operand. operation_name = ( - f"{left_operand_full_name}.{fn}({right_operand_full_name})" - ) # type: ignore + f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore + ) else: # pass the operands as arguments to the function. operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = f"{fn}({operation_name})" else: # apply operator to operands (left operand right_operand) operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = format.wrap(operation_name, "(") else: # apply operator to left operand ( left_operand) @@ -1872,6 +1884,32 @@ def __init__( kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) BaseVar.__init__(self, **kwargs) # type: ignore + @override + def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: + """Replace the attributes of the ComputedVar. + + Args: + merge_var_data: VarData to merge into the existing VarData. + **kwargs: Var fields to update. + + Returns: + The new ComputedVar instance. + """ + return ComputedVar( + fget=kwargs.get("fget", self.fget), + initial_value=kwargs.get("initial_value", self._initial_value), + cache=kwargs.get("cache", self._cache), + _var_name=kwargs.get("_var_name", self._var_name), + _var_type=kwargs.get("_var_type", self._var_type), + _var_is_local=kwargs.get("_var_is_local", self._var_is_local), + _var_is_string=kwargs.get("_var_is_string", self._var_is_string), + _var_full_name_needs_state_prefix=kwargs.get( + "_var_full_name_needs_state_prefix", + self._var_full_name_needs_state_prefix, + ), + _var_data=VarData.merge(self._var_data, merge_var_data), + ) + @property def _cache_attr(self) -> str: """Get the attribute used to cache the value on the instance. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index fb2ed46573..169e2d919c 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ... def _extract_var_data(value: Iterable) -> list[VarData | None]: ... class VarData(Base): - state: str - imports: dict[str, set[ImportVar]] - hooks: Dict[str, None] - interpolations: List[Tuple[int, int]] + state: str = "" + imports: dict[str, List[ImportVar]] = {} + hooks: Dict[str, None] = {} + interpolations: List[Tuple[int, int]] = [] @classmethod def merge(cls, *others: VarData | None) -> VarData | None: ... @@ -50,11 +50,11 @@ class Var: _var_data: VarData | None = None @classmethod def create( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False + cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, ) -> Optional[Var]: ... @classmethod def create_safe( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False + cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, ) -> Var: ... @classmethod def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... @@ -139,6 +139,7 @@ class ComputedVar(Var): def _cache_attr(self) -> str: ... def __get__(self, instance, owner): ... def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... + def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ... def mark_dirty(self, instance) -> None: ... def _determine_var_type(self) -> Type: ... @overload diff --git a/tests/components/core/test_foreach.py b/tests/components/core/test_foreach.py index 9691ed50e6..6c41845903 100644 --- a/tests/components/core/test_foreach.py +++ b/tests/components/core/test_foreach.py @@ -3,8 +3,9 @@ import pytest from reflex.components import box, el, foreach, text +from reflex.components.component import Component from reflex.components.core.foreach import Foreach, ForeachRenderError, ForeachVarError -from reflex.state import BaseState +from reflex.state import BaseState, ComponentState from reflex.vars import Var @@ -37,6 +38,25 @@ class ForEachState(BaseState): color_index_tuple: Tuple[int, str] = (0, "red") +class TestComponentState(ComponentState): + """A test component state.""" + + foo: bool + + @classmethod + def get_component(cls, *children, **props) -> Component: + """Get the component. + + Args: + children: The children components. + props: The component props. + + Returns: + The component. + """ + return el.div(*children, **props) + + def display_color(color): assert color._var_type == str return box(text(color)) @@ -252,3 +272,12 @@ def test_foreach_component_styles(): ) component._add_style_recursive({box: {"color": "red"}}) assert 'css={{"color": "red"}}' in str(component) + + +def test_foreach_component_state(): + """Test that using a component state to render in the foreach raises an error.""" + with pytest.raises(TypeError): + Foreach.create( + ForEachState.colors_list, + TestComponentState.create, + ) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 6245746c98..e4d7205d73 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1063,7 +1063,7 @@ def test_stateful_banner(): TEST_VAR = Var.create_safe("test")._replace( merge_var_data=VarData( hooks={"useTest": None}, - imports={"test": {ImportVar(tag="test")}}, + imports={"test": [ImportVar(tag="test")]}, state="Test", interpolations=[], ) @@ -1953,6 +1953,44 @@ def add_custom_code(self): } +def test_component_add_hooks_var(): + class HookComponent(Component): + def add_hooks(self): + return [ + "const hook3 = useRef(null)", + "const hook1 = 42", + Var.create( + "useEffect(() => () => {}, [])", + _var_data=VarData( + hooks={ + "const hook2 = 43": None, + "const hook3 = useRef(null)": None, + }, + imports={"react": [ImportVar(tag="useEffect")]}, + ), + ), + Var.create( + "const hook3 = useRef(null)", + _var_data=VarData( + imports={"react": [ImportVar(tag="useRef")]}, + ), + ), + ] + + assert list(HookComponent()._get_all_hooks()) == [ + "const hook3 = useRef(null)", + "const hook1 = 42", + "const hook2 = 43", + "useEffect(() => () => {}, [])", + ] + imports = HookComponent()._get_all_imports() + assert len(imports) == 1 + assert "react" in imports + assert len(imports["react"]) == 2 + assert ImportVar(tag="useRef") in imports["react"] + assert ImportVar(tag="useEffect") in imports["react"] + + def test_add_style_embedded_vars(test_state: BaseState): """Test that add_style works with embedded vars when returning a plain dict. diff --git a/tests/test_event.py b/tests/test_event.py index 8852631576..284542a434 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -158,12 +158,29 @@ def test_fn_with_args(_, arg1, arg2): @pytest.mark.parametrize( "input,output", [ - (("/path", None), 'Event("_redirect", {path:`/path`,external:false})'), - (("/path", True), 'Event("_redirect", {path:`/path`,external:true})'), - (("/path", False), 'Event("_redirect", {path:`/path`,external:false})'), ( - (Var.create_safe("path"), None), - 'Event("_redirect", {path:path,external:false})', + ("/path", None, None), + 'Event("_redirect", {path:`/path`,external:false,replace:false})', + ), + ( + ("/path", True, None), + 'Event("_redirect", {path:`/path`,external:true,replace:false})', + ), + ( + ("/path", False, None), + 'Event("_redirect", {path:`/path`,external:false,replace:false})', + ), + ( + (Var.create_safe("path"), None, None), + 'Event("_redirect", {path:path,external:false,replace:false})', + ), + ( + ("/path", None, True), + 'Event("_redirect", {path:`/path`,external:false,replace:true})', + ), + ( + ("/path", True, True), + 'Event("_redirect", {path:`/path`,external:true,replace:true})', ), ], ) @@ -174,11 +191,13 @@ def test_event_redirect(input, output): input: The input for running the test. output: The expected output to validate the test. """ - path, external = input - if external is None: - spec = event.redirect(path) - else: - spec = event.redirect(path, external=external) + path, external, replace = input + kwargs = {} + if external is not None: + kwargs["external"] = external + if replace is not None: + kwargs["replace"] = replace + spec = event.redirect(path, **kwargs) assert isinstance(spec, EventSpec) assert spec.handler.fn.__qualname__ == "_redirect"