diff --git a/reflex/app.py b/reflex/app.py index 9e5c2541ac..43edee983f 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -9,6 +9,7 @@ import functools import inspect import io +import json import multiprocessing import os import platform @@ -1096,6 +1097,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: if delta: # When the state is modified reset dirty status and emit the delta to the frontend. state._clean() + print(dir(state.router)) await self.event_namespace.emit_update( update=StateUpdate(delta=delta), sid=state.router.session.session_id, @@ -1531,8 +1533,9 @@ async def on_event(self, sid, data): sid: The Socket.IO session id. data: The event data. """ + fields = json.loads(data) # Get the event. - event = Event.parse_raw(data) + event = Event(**{k: v for k, v in fields.items() if k != "handler"}) self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/components/component.py b/reflex/components/component.py index f7e25e5534..d9a8ab332e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -4,6 +4,7 @@ import copy import typing +import warnings from abc import ABC, abstractmethod from functools import lru_cache, wraps from hashlib import md5 @@ -169,6 +170,8 @@ def evaluate_style_namespaces(style: ComponentStyle) -> dict: ] ComponentChild = Union[types.PrimitiveType, Var, BaseComponent] +warnings.filterwarnings("ignore", message="fields may not start with an underscore") + class Component(BaseComponent, ABC): """A component with style, event trigger and other props.""" @@ -195,7 +198,7 @@ class Component(BaseComponent, ABC): class_name: Any = None # Special component props. - special_props: Set[ImmutableVar] = set() + special_props: List[ImmutableVar] = [] # Whether the component should take the focus once the page is loaded autofocus: bool = False @@ -656,7 +659,7 @@ def _render(self, props: dict[str, Any] | None = None) -> Tag: """ # Create the base tag. tag = Tag( - name=self.tag if not self.alias else self.alias, + name=(self.tag if not self.alias else self.alias) or "", special_props=self.special_props, ) @@ -2245,7 +2248,7 @@ def render(self) -> dict: Returns: The tag to render. """ - return dict(Tag(name=self.tag)) + return dict(Tag(name=self.tag or "")) def __str__(self) -> str: """Represent the component in React. diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index b6fe1024ae..7501934d86 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -247,9 +247,9 @@ def create(cls, *children, **props) -> Component: } # The file input to use. upload = Input.create(type="file") - upload.special_props = { + upload.special_props = [ ImmutableVar(_var_name="{...getInputProps()}", _var_type=None) - } + ] # The dropzone to use. zone = Box.create( @@ -257,9 +257,9 @@ def create(cls, *children, **props) -> Component: *children, **{k: v for k, v in props.items() if k not in supported_props}, ) - zone.special_props = { + zone.special_props = [ ImmutableVar(_var_name="{...getRootProps()}", _var_type=None) - } + ] # Create the component. upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID) diff --git a/reflex/components/el/elements/metadata.py b/reflex/components/el/elements/metadata.py index fc8e1d9d6d..df5e1e9017 100644 --- a/reflex/components/el/elements/metadata.py +++ b/reflex/components/el/elements/metadata.py @@ -1,6 +1,6 @@ """Element classes. This is an auto-generated file. Do not edit. See ../generate.py.""" -from typing import Set, Union +from typing import List, Union from reflex.components.el.element import Element from reflex.ivars.base import ImmutableVar @@ -90,9 +90,9 @@ class StyleEl(Element): # noqa: E742 media: Var[Union[str, int, bool]] - special_props: Set[ImmutableVar] = { + special_props: List[ImmutableVar] = [ ImmutableVar.create_safe("suppressHydrationWarning") - } + ] base = Base.create diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index b44ca25dda..83f647f4c3 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -195,17 +195,17 @@ def get_component(self, tag: str, **props) -> Component: if tag not in self.component_map: raise ValueError(f"No markdown component found for tag: {tag}.") - special_props = {_PROPS_IN_TAG} + special_props = [_PROPS_IN_TAG] children = [_CHILDREN] # For certain tags, the props from the markdown renderer are not actually valid for the component. if tag in NO_PROPS_TAGS: - special_props = set() + special_props = [] # If the children are set as a prop, don't pass them as children. children_prop = props.pop("children", None) if children_prop is not None: - special_props.add( + special_props.append( ImmutableVar.create_safe(f"children={{{str(children_prop)}}}") ) children = [] diff --git a/reflex/components/moment/moment.py b/reflex/components/moment/moment.py index 958ba6c57b..54411f870b 100644 --- a/reflex/components/moment/moment.py +++ b/reflex/components/moment/moment.py @@ -1,26 +1,27 @@ """Moment component for humanized date rendering.""" +import dataclasses from typing import List, Optional -from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.event import EventHandler from reflex.utils.imports import ImportDict from reflex.vars import Var -class MomentDelta(Base): +@dataclasses.dataclass(frozen=True) +class MomentDelta: """A delta used for add/subtract prop in Moment.""" - years: Optional[int] - quarters: Optional[int] - months: Optional[int] - weeks: Optional[int] - days: Optional[int] - hours: Optional[int] - minutess: Optional[int] - seconds: Optional[int] - milliseconds: Optional[int] + years: Optional[int] = dataclasses.field(default=None) + quarters: Optional[int] = dataclasses.field(default=None) + months: Optional[int] = dataclasses.field(default=None) + weeks: Optional[int] = dataclasses.field(default=None) + days: Optional[int] = dataclasses.field(default=None) + hours: Optional[int] = dataclasses.field(default=None) + minutess: Optional[int] = dataclasses.field(default=None) + seconds: Optional[int] = dataclasses.field(default=None) + milliseconds: Optional[int] = dataclasses.field(default=None) class Moment(NoSSRComponent): diff --git a/reflex/components/moment/moment.pyi b/reflex/components/moment/moment.pyi index 168a239d79..2a19bcd01b 100644 --- a/reflex/components/moment/moment.pyi +++ b/reflex/components/moment/moment.pyi @@ -3,9 +3,9 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ +import dataclasses from typing import Any, Callable, Dict, Optional, Union, overload -from reflex.base import Base from reflex.components.component import NoSSRComponent from reflex.event import EventHandler, EventSpec from reflex.ivars.base import ImmutableVar @@ -13,7 +13,8 @@ from reflex.style import Style from reflex.utils.imports import ImportDict from reflex.vars import Var -class MomentDelta(Base): +@dataclasses.dataclass(frozen=True) +class MomentDelta: years: Optional[int] quarters: Optional[int] months: Optional[int] diff --git a/reflex/components/plotly/plotly.py b/reflex/components/plotly/plotly.py index ed7040d1c8..c226e6e953 100644 --- a/reflex/components/plotly/plotly.py +++ b/reflex/components/plotly/plotly.py @@ -267,7 +267,7 @@ def _render(self): template_dict = LiteralVar.create({"layout": {"template": self.template}}) merge_dicts.append(template_dict.without_data()) if merge_dicts: - tag.special_props.add( + tag.special_props.append( # Merge all dictionaries and spread the result over props. ImmutableVar.create_safe( f"{{...mergician({str(figure)}," @@ -276,5 +276,5 @@ def _render(self): ) else: # Spread the figure dict over props, nothing to merge. - tag.special_props.add(ImmutableVar.create_safe(f"{{...{str(figure)}}}")) + tag.special_props.append(ImmutableVar.create_safe(f"{{...{str(figure)}}}")) return tag diff --git a/reflex/components/tags/cond_tag.py b/reflex/components/tags/cond_tag.py index 3143890c4d..7bdf9a3c77 100644 --- a/reflex/components/tags/cond_tag.py +++ b/reflex/components/tags/cond_tag.py @@ -1,19 +1,22 @@ """Tag to conditionally render components.""" +import dataclasses from typing import Any, Dict, Optional from reflex.components.tags.tag import Tag +from reflex.ivars.base import LiteralVar from reflex.vars import Var +@dataclasses.dataclass() class CondTag(Tag): """A conditional tag.""" # The condition to determine which component to render. - cond: Var[Any] + cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True)) # The code to render if the condition is true. - true_value: Dict + true_value: Dict = dataclasses.field(default_factory=dict) # The code to render if the condition is false. - false_value: Optional[Dict] + false_value: Optional[Dict] = None diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py index ee7a636284..ff6925f56f 100644 --- a/reflex/components/tags/iter_tag.py +++ b/reflex/components/tags/iter_tag.py @@ -2,31 +2,36 @@ from __future__ import annotations +import dataclasses import inspect from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, Union, get_args from reflex.components.tags.tag import Tag from reflex.ivars.base import ImmutableVar -from reflex.vars import Var +from reflex.ivars.sequence import LiteralArrayVar +from reflex.vars import Var, get_unique_variable_name if TYPE_CHECKING: from reflex.components.component import Component +@dataclasses.dataclass() class IterTag(Tag): """An iterator tag.""" # The var to iterate over. - iterable: Var[List] + iterable: Var[List] = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) # The component render function for each item in the iterable. - render_fn: Callable + render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x) # The name of the arg var. - arg_var_name: str + arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) # The name of the index var. - index_var_name: str + index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) def get_iterable_var_type(self) -> Type: """Get the type of the iterable var. diff --git a/reflex/components/tags/match_tag.py b/reflex/components/tags/match_tag.py index c2f6649d51..b67ed62ccd 100644 --- a/reflex/components/tags/match_tag.py +++ b/reflex/components/tags/match_tag.py @@ -1,19 +1,22 @@ """Tag to conditionally match cases.""" +import dataclasses from typing import Any, List from reflex.components.tags.tag import Tag +from reflex.ivars.base import LiteralVar from reflex.vars import Var +@dataclasses.dataclass() class MatchTag(Tag): """A match tag.""" # The condition to determine which case to match. - cond: Var[Any] + cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True)) # The list of match cases to be matched. - match_cases: List[Any] + match_cases: List[Any] = dataclasses.field(default_factory=list) # The catchall case to match. - default: Any + default: Any = dataclasses.field(default=LiteralVar.create(None)) diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 810da30f95..8c97d72c53 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -2,22 +2,23 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Set, Tuple, Union +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Union -from reflex.base import Base from reflex.event import EventChain from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.utils import format, types -class Tag(Base): +@dataclasses.dataclass() +class Tag: """A React tag.""" # The name of the tag. name: str = "" # The props of the tag. - props: Dict[str, Any] = {} + props: Dict[str, Any] = dataclasses.field(default_factory=dict) # The inner contents of the tag. contents: str = "" @@ -26,25 +27,18 @@ class Tag(Base): args: Optional[Tuple[str, ...]] = None # Special props that aren't key value pairs. - special_props: Set[ImmutableVar] = set() + special_props: List[ImmutableVar] = dataclasses.field(default_factory=list) # The children components. - children: List[Any] = [] - - def __init__(self, *args, **kwargs): - """Initialize the tag. - - Args: - *args: Args to initialize the tag. - **kwargs: Kwargs to initialize the tag. - """ - # Convert any props to vars. - if "props" in kwargs: - kwargs["props"] = { - name: LiteralVar.create(value) - for name, value in kwargs["props"].items() - } - super().__init__(*args, **kwargs) + children: List[Any] = dataclasses.field(default_factory=list) + + def __post_init__(self): + """Post initialize the tag.""" + object.__setattr__( + self, + "props", + {name: LiteralVar.create(value) for name, value in self.props.items()}, + ) def format_props(self) -> List: """Format the tag's props. @@ -54,6 +48,29 @@ def format_props(self) -> List: """ return format.format_props(*self.special_props, **self.props) + def set(self, **kwargs: Any): + """Set the tag's fields. + + Args: + kwargs: The fields to set. + + Returns: + The tag with the fields + """ + for name, value in kwargs.items(): + setattr(self, name, value) + + return self + + def __iter__(self): + """Iterate over the tag's fields. + + Yields: + Tuple[str, Any]: The field name and value. + """ + for field in dataclasses.fields(self): + yield field.name, getattr(self, field.name) + def add_props(self, **kwargs: Optional[Any]) -> Tag: """Add props to the tag. diff --git a/reflex/event.py b/reflex/event.py index 8384f06a86..73fecfc039 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import inspect import types import urllib.parse @@ -18,7 +19,6 @@ ) from reflex import constants -from reflex.base import Base from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.ivars.function import FunctionStringVar, FunctionVar from reflex.ivars.object import ObjectVar @@ -33,7 +33,11 @@ from typing_extensions import Annotated -class Event(Base): +@dataclasses.dataclass( + init=True, + frozen=True, +) +class Event: """An event that describes any state change in the app.""" # The token to specify the client that the event is for. @@ -43,10 +47,10 @@ class Event(Base): name: str # The routing data where event occurred - router_data: Dict[str, Any] = {} + router_data: Dict[str, Any] = dataclasses.field(default_factory=dict) # The event payload. - payload: Dict[str, Any] = {} + payload: Dict[str, Any] = dataclasses.field(default_factory=dict) @property def substate_token(self) -> str: @@ -81,11 +85,15 @@ def background(fn): return fn -class EventActionsMixin(Base): +@dataclasses.dataclass( + init=True, + frozen=True, +) +class EventActionsMixin: """Mixin for DOM event actions.""" # Whether to `preventDefault` or `stopPropagation` on the event. - event_actions: Dict[str, Union[bool, int]] = {} + event_actions: Dict[str, Union[bool, int]] = dataclasses.field(default_factory=dict) @property def stop_propagation(self): @@ -94,8 +102,9 @@ def stop_propagation(self): Returns: New EventHandler-like with stopPropagation set to True. """ - return self.copy( - update={"event_actions": {"stopPropagation": True, **self.event_actions}}, + return dataclasses.replace( + self, + event_actions={"stopPropagation": True, **self.event_actions}, ) @property @@ -105,8 +114,9 @@ def prevent_default(self): Returns: New EventHandler-like with preventDefault set to True. """ - return self.copy( - update={"event_actions": {"preventDefault": True, **self.event_actions}}, + return dataclasses.replace( + self, + event_actions={"preventDefault": True, **self.event_actions}, ) def throttle(self, limit_ms: int): @@ -118,8 +128,9 @@ def throttle(self, limit_ms: int): Returns: New EventHandler-like with throttle set to limit_ms. """ - return self.copy( - update={"event_actions": {"throttle": limit_ms, **self.event_actions}}, + return dataclasses.replace( + self, + event_actions={"throttle": limit_ms, **self.event_actions}, ) def debounce(self, delay_ms: int): @@ -131,26 +142,25 @@ def debounce(self, delay_ms: int): Returns: New EventHandler-like with debounce set to delay_ms. """ - return self.copy( - update={"event_actions": {"debounce": delay_ms, **self.event_actions}}, + return dataclasses.replace( + self, + event_actions={"debounce": delay_ms, **self.event_actions}, ) +@dataclasses.dataclass( + init=True, + frozen=True, +) class EventHandler(EventActionsMixin): """An event handler responds to an event to update the state.""" # The function to call in response to the event. - fn: Any + fn: Any = dataclasses.field(default=None) # The full name of the state class this event handler is attached to. # Empty string means this event handler is a server side event. - state_full_name: str = "" - - class Config: - """The Pydantic config.""" - - # Needed to allow serialization of Callable. - frozen = True + state_full_name: str = dataclasses.field(default="") @classmethod def __class_getitem__(cls, args_spec: str) -> Annotated: @@ -215,6 +225,10 @@ def __call__(self, *args: Any) -> EventSpec: ) +@dataclasses.dataclass( + init=True, + frozen=True, +) class EventSpec(EventActionsMixin): """An event specification. @@ -223,19 +237,37 @@ class EventSpec(EventActionsMixin): """ # The event handler. - handler: EventHandler + handler: EventHandler = dataclasses.field(default=None) # type: ignore # The handler on the client to process event. - client_handler_name: str = "" + client_handler_name: str = dataclasses.field(default="") # The arguments to pass to the function. - args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = () + args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = dataclasses.field( + default_factory=tuple + ) - class Config: - """The Pydantic config.""" + def __init__( + self, + handler: EventHandler, + event_actions: Dict[str, Union[bool, int]] | None = None, + client_handler_name: str = "", + args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = tuple(), + ): + """Initialize an EventSpec. - # Required to allow tuple fields. - frozen = True + Args: + event_actions: The event actions. + handler: The event handler. + client_handler_name: The client handler name. + args: The arguments to pass to the function. + """ + if event_actions is None: + event_actions = {} + object.__setattr__(self, "event_actions", event_actions) + object.__setattr__(self, "handler", handler) + object.__setattr__(self, "client_handler_name", client_handler_name) + object.__setattr__(self, "args", args or tuple()) def with_args( self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] @@ -286,6 +318,9 @@ def add_args(self, *args: ImmutableVar) -> EventSpec: return self.with_args(self.args + new_payload) +@dataclasses.dataclass( + frozen=True, +) class CallableEventSpec(EventSpec): """Decorate an EventSpec-returning function to act as both a EventSpec and a function. @@ -305,10 +340,13 @@ def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs): if fn is not None: default_event_spec = fn() super().__init__( - fn=fn, # type: ignore - **default_event_spec.dict(), + event_actions=default_event_spec.event_actions, + client_handler_name=default_event_spec.client_handler_name, + args=default_event_spec.args, + handler=default_event_spec.handler, **kwargs, ) + object.__setattr__(self, "fn", fn) else: super().__init__(**kwargs) @@ -332,12 +370,16 @@ def __call__(self, *args, **kwargs) -> EventSpec: return self.fn(*args, **kwargs) +@dataclasses.dataclass( + init=True, + frozen=True, +) class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[EventSpec] + events: List[EventSpec] = dataclasses.field(default_factory=list) - args_spec: Optional[Callable] + args_spec: Optional[Callable] = dataclasses.field(default=None) # These chains can be used for their side effects when no other events are desired. @@ -345,14 +387,22 @@ class EventChain(EventActionsMixin): prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default -class Target(Base): +@dataclasses.dataclass( + init=True, + frozen=True, +) +class Target: """A Javascript event target.""" checked: bool = False value: Any = None -class FrontendEvent(Base): +@dataclasses.dataclass( + init=True, + frozen=True, +) +class FrontendEvent: """A Javascript event.""" target: Target = Target() @@ -360,7 +410,11 @@ class FrontendEvent(Base): value: Any = None -class FileUpload(Base): +@dataclasses.dataclass( + init=True, + frozen=True, +) +class FileUpload: """Class to represent a file upload.""" upload_id: Optional[str] = None diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index eeca86ed8d..93c257afd8 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -433,6 +433,9 @@ def to( if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) + if dataclasses.is_dataclass(output): + return ToObjectOperation.create(self, var_type or dict) + if issubclass(output, FunctionVar): # if fixed_type is not None and not issubclass(fixed_type, Callable): # raise TypeError( @@ -491,7 +494,11 @@ def guess_type(self) -> ImmutableVar: ): return self.to(NumberVar, self._var_type) - if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types): + if all( + inspect.isclass(t) + and (issubclass(t, Base) or dataclasses.is_dataclass(t)) + for t in inner_types + ): return self.to(ObjectVar, self._var_type) return self @@ -511,6 +518,8 @@ def guess_type(self) -> ImmutableVar: return self.to(StringVar, self._var_type) if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) + if dataclasses.is_dataclass(fixed_type): + return self.to(ObjectVar, self._var_type) return self def get_default_value(self) -> Any: @@ -998,6 +1007,16 @@ def create( ) return LiteralVar.create(serialized_value, _var_data=_var_data) + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return LiteralObjectVar.create( + { + k: (None if callable(v) else v) + for k, v in dataclasses.asdict(value).items() + }, + _var_type=type(value), + _var_data=_var_data, + ) + raise TypeError( f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." ) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index b5694e22f9..46b524cd76 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses from typing import TYPE_CHECKING, Optional from reflex import constants @@ -14,6 +15,7 @@ from reflex.app import App +@dataclasses.dataclass(init=True) class HydrateMiddleware(Middleware): """Middleware to handle initial app hydration.""" diff --git a/reflex/middleware/middleware.py b/reflex/middleware/middleware.py index 76cbcfe9a7..ef9de0bdeb 100644 --- a/reflex/middleware/middleware.py +++ b/reflex/middleware/middleware.py @@ -2,10 +2,9 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from reflex.base import Base from reflex.event import Event from reflex.state import BaseState, StateUpdate @@ -13,9 +12,10 @@ from reflex.app import App -class Middleware(Base, ABC): +class Middleware(ABC): """Middleware to preprocess and postprocess requests.""" + @abstractmethod async def preprocess( self, app: App, state: BaseState, event: Event ) -> Optional[StateUpdate]: diff --git a/reflex/state.py b/reflex/state.py index 7e759f78ae..79230e2af3 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -5,8 +5,10 @@ import asyncio import contextlib import copy +import dataclasses import functools import inspect +import json import os import uuid from abc import ABC, abstractmethod @@ -84,13 +86,15 @@ TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb -class HeaderData(Base): +@dataclasses.dataclass(frozen=True) +class HeaderData: """An object containing headers data.""" host: str = "" origin: str = "" upgrade: str = "" connection: str = "" + cookie: str = "" pragma: str = "" cache_control: str = "" user_agent: str = "" @@ -106,13 +110,16 @@ def __init__(self, router_data: Optional[dict] = None): Args: router_data: the router_data dict. """ - super().__init__() if router_data: for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): - setattr(self, format.to_snake_case(k), v) + object.__setattr__(self, format.to_snake_case(k), v) + else: + for k in dataclasses.fields(self): + object.__setattr__(self, k.name, "") -class PageData(Base): +@dataclasses.dataclass(frozen=True) +class PageData: """An object containing page data.""" host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) @@ -120,7 +127,7 @@ class PageData(Base): raw_path: str = "" full_path: str = "" full_raw_path: str = "" - params: dict = {} + params: dict = dataclasses.field(default_factory=dict) def __init__(self, router_data: Optional[dict] = None): """Initalize the PageData object based on router_data. @@ -128,17 +135,34 @@ def __init__(self, router_data: Optional[dict] = None): Args: router_data: the router_data dict. """ - super().__init__() if router_data: - self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin") - self.path = router_data.get(constants.RouteVar.PATH, "") - self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "") - self.full_path = f"{self.host}{self.path}" - self.full_raw_path = f"{self.host}{self.raw_path}" - self.params = router_data.get(constants.RouteVar.QUERY, {}) + object.__setattr__( + self, + "host", + router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), + ) + object.__setattr__( + self, "path", router_data.get(constants.RouteVar.PATH, "") + ) + object.__setattr__( + self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") + ) + object.__setattr__(self, "full_path", f"{self.host}{self.path}") + object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") + object.__setattr__( + self, "params", router_data.get(constants.RouteVar.QUERY, {}) + ) + else: + object.__setattr__(self, "host", "") + object.__setattr__(self, "path", "") + object.__setattr__(self, "raw_path", "") + object.__setattr__(self, "full_path", "") + object.__setattr__(self, "full_raw_path", "") + object.__setattr__(self, "params", {}) -class SessionData(Base): +@dataclasses.dataclass(frozen=True, init=False) +class SessionData: """An object containing session data.""" client_token: str = "" @@ -151,19 +175,24 @@ def __init__(self, router_data: Optional[dict] = None): Args: router_data: the router_data dict. """ - super().__init__() if router_data: - self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") - self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") - self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + else: + client_token = client_ip = session_id = "" + object.__setattr__(self, "client_token", client_token) + object.__setattr__(self, "client_ip", client_ip) + object.__setattr__(self, "session_id", session_id) -class RouterData(Base): +@dataclasses.dataclass(frozen=True, init=False) +class RouterData: """An object containing RouterData.""" - session: SessionData = SessionData() - headers: HeaderData = HeaderData() - page: PageData = PageData() + session: SessionData = dataclasses.field(default_factory=SessionData) + headers: HeaderData = dataclasses.field(default_factory=HeaderData) + page: PageData = dataclasses.field(default_factory=PageData) def __init__(self, router_data: Optional[dict] = None): """Initialize the RouterData object. @@ -171,10 +200,30 @@ def __init__(self, router_data: Optional[dict] = None): Args: router_data: the router_data dict. """ - super().__init__() - self.session = SessionData(router_data) - self.headers = HeaderData(router_data) - self.page = PageData(router_data) + object.__setattr__(self, "session", SessionData(router_data)) + object.__setattr__(self, "headers", HeaderData(router_data)) + object.__setattr__(self, "page", PageData(router_data)) + + def toJson(self) -> str: + """Convert the object to a JSON string. + + Returns: + The JSON string. + """ + return json.dumps(dataclasses.asdict(self)) + + +@serializer +def serialize_routerdata(value: RouterData) -> str: + """Serialize a RouterData instance. + + Args: + value: The RouterData to serialize. + + Returns: + The serialized RouterData. + """ + return value.toJson() def _no_chain_background_task( @@ -250,10 +299,11 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]: return token, state_name +@dataclasses.dataclass(frozen=True, init=False) class EventHandlerSetVar(EventHandler): """A special event handler to wrap setvar functionality.""" - state_cls: Type[BaseState] + state_cls: Type[BaseState] = dataclasses.field(init=False) def __init__(self, state_cls: Type[BaseState]): """Initialize the EventHandlerSetVar. @@ -264,8 +314,8 @@ def __init__(self, state_cls: Type[BaseState]): super().__init__( fn=type(self).setvar, state_full_name=state_cls.get_full_name(), - state_cls=state_cls, # type: ignore ) + object.__setattr__(self, "state_cls", state_cls) def setvar(self, var_name: str, value: Any): """Set the state variable to the value of the event. @@ -1911,8 +1961,13 @@ def dict( self.dirty_vars.update(self._always_dirty_computed_vars) self._mark_dirty() + def dictify(value: Any): + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return dataclasses.asdict(value) + return value + base_vars = { - prop_name: self.get_value(key=getattr(self, prop_name)) + prop_name: dictify(self.get_value(getattr(self, prop_name))) for prop_name in self.base_vars } if initial and include_computed: @@ -1992,9 +2047,6 @@ def __getstate__(self): return state -EventHandlerSetVar.update_forward_refs() - - class State(BaseState): """The app Base State.""" @@ -2426,18 +2478,29 @@ def _as_state_update(self, *args, **kwargs) -> StateUpdate: self._self_mutable = original_mutable -class StateUpdate(Base): +@dataclasses.dataclass( + frozen=True, +) +class StateUpdate: """A state update sent to the frontend.""" # The state delta. - delta: Delta = {} + delta: Delta = dataclasses.field(default_factory=dict) # Events to be added to the event queue. - events: List[Event] = [] + events: List[Event] = dataclasses.field(default_factory=list) # Whether this is the final state update for the event. final: bool = True + def json(self) -> str: + """Convert the state update to a JSON string. + + Returns: + The state update as a JSON string. + """ + return json.dumps(dataclasses.asdict(self)) + class StateManager(Base, ABC): """A class to manage many client states.""" diff --git a/reflex/utils/format.py b/reflex/utils/format.py index fd7737cfe9..eff26b2150 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import inspect import json import os @@ -623,6 +624,14 @@ def format_state(value: Any, key: Optional[str] = None) -> Any: if isinstance(value, dict): return {k: format_state(v, k) for k, v in value.items()} + # Hand dataclasses. + if dataclasses.is_dataclass(value): + if isinstance(value, type): + raise TypeError( + f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass." + ) + return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()} + # Handle lists, sets, typles. if isinstance(value, types.StateIterBases): return [format_state(v) for v in value] diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index d58c2bf3f6..8f53ed07a3 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -2,10 +2,9 @@ from __future__ import annotations +import dataclasses from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -from reflex.base import Base +from typing import DefaultDict, Dict, List, Optional, Tuple, Union def merge_imports( @@ -19,12 +18,22 @@ def merge_imports( Returns: The merged import dicts. """ - all_imports = defaultdict(list) + all_imports: DefaultDict[str, List[ImportVar]] = defaultdict(list) for import_dict in imports: for lib, fields in ( import_dict if isinstance(import_dict, tuple) else import_dict.items() ): - all_imports[lib].extend(fields) + if isinstance(fields, (list, tuple, set)): + all_imports[lib].extend( + ( + ImportVar(field) if isinstance(field, str) else field + for field in fields + ) + ) + else: + all_imports[lib].append( + ImportVar(fields) if isinstance(fields, str) else fields + ) return all_imports @@ -75,7 +84,8 @@ def collapse_imports( } -class ImportVar(Base): +@dataclasses.dataclass(order=True, frozen=True) +class ImportVar: """An import var.""" # The name of the import tag. @@ -111,73 +121,6 @@ def name(self) -> str: else: return self.tag or "" - def __lt__(self, other: ImportVar) -> bool: - """Compare two ImportVar objects. - - Args: - other: The other ImportVar object to compare. - - Returns: - Whether this ImportVar object is less than the other. - """ - return ( - self.tag, - self.is_default, - self.alias, - self.install, - self.render, - self.transpile, - ) < ( - other.tag, - other.is_default, - other.alias, - other.install, - other.render, - other.transpile, - ) - - def __eq__(self, other: ImportVar) -> bool: - """Check if two ImportVar objects are equal. - - Args: - other: The other ImportVar object to compare. - - Returns: - Whether the two ImportVar objects are equal. - """ - return ( - self.tag, - self.is_default, - self.alias, - self.install, - self.render, - self.transpile, - ) == ( - other.tag, - other.is_default, - other.alias, - other.install, - other.render, - other.transpile, - ) - - def __hash__(self) -> int: - """Hash the ImportVar object. - - Returns: - The hash of the ImportVar object. - """ - return hash( - ( - self.tag, - self.is_default, - self.alias, - self.install, - self.render, - self.transpile, - ) - ) - ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]] ImportDict = Dict[str, ImportTypes] diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 3384be5cf1..78139034bc 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import functools import glob import importlib @@ -32,7 +33,6 @@ from redis.asyncio import Redis from reflex import constants, model -from reflex.base import Base from reflex.compiler import templates from reflex.config import Config, get_config from reflex.utils import console, net, path_ops, processes @@ -43,7 +43,8 @@ CURRENTLY_INSTALLING_NODE = False -class Template(Base): +@dataclasses.dataclass(frozen=True) +class Template: """A template for a Reflex app.""" name: str @@ -52,7 +53,8 @@ class Template(Base): demo_url: str -class CpuInfo(Base): +@dataclasses.dataclass(frozen=True) +class CpuInfo: """Model to save cpu info.""" manufacturer_id: Optional[str] @@ -1279,7 +1281,7 @@ def get_release_by_tag(tag: str) -> dict | None: None, ) return { - tp["name"]: Template.parse_obj(tp) + tp["name"]: Template(**tp) for tp in templates_data if not tp["hidden"] and tp["code_url"] is not None } diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index e027ed81a7..03e2b943b7 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import dataclasses import multiprocessing import platform import warnings @@ -144,7 +145,7 @@ def _prepare_event(event: str, **kwargs) -> dict: "python_version": get_python_version(), "cpu_count": get_cpu_count(), "memory": get_memory(), - "cpu_info": dict(cpuinfo) if cpuinfo else {}, + "cpu_info": dataclasses.asdict(cpuinfo) if cpuinfo else {}, **additional_fields, }, "timestamp": stamp, diff --git a/reflex/utils/types.py b/reflex/utils/types.py index f4463fa920..ba58408ff1 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextlib +import dataclasses import inspect import sys import types @@ -480,7 +481,11 @@ def is_valid_var_type(type_: Type) -> bool: if is_union(type_): return all((is_valid_var_type(arg) for arg in get_args(type_))) - return _issubclass(type_, StateVar) or serializers.has_serializer(type_) + return ( + _issubclass(type_, StateVar) + or serializers.has_serializer(type_) + or dataclasses.is_dataclass(type_) + ) def is_backend_base_variable(name: str, cls: Type) -> bool: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 592e1ca6e4..79a72b9973 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -637,21 +637,21 @@ def test_component_create_unallowed_types(children, test_component): "props": [], "contents": "", "args": None, - "special_props": set(), + "special_props": [], "children": [ { "name": "RadixThemesText", "props": ['as={"p"}'], "contents": "", "args": None, - "special_props": set(), + "special_props": [], "children": [ { "name": "", "props": [], "contents": '{"first_text"}', "args": None, - "special_props": set(), + "special_props": [], "children": [], "autofocus": False, } @@ -679,13 +679,13 @@ def test_component_create_unallowed_types(children, test_component): "contents": '{"first_text"}', "name": "", "props": [], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "RadixThemesText", "props": ['as={"p"}'], - "special_props": set(), + "special_props": [], }, { "args": None, @@ -698,19 +698,19 @@ def test_component_create_unallowed_types(children, test_component): "contents": '{"second_text"}', "name": "", "props": [], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "RadixThemesText", "props": ['as={"p"}'], - "special_props": set(), + "special_props": [], }, ], "contents": "", "name": "Fragment", "props": [], - "special_props": set(), + "special_props": [], }, ), ( @@ -730,13 +730,13 @@ def test_component_create_unallowed_types(children, test_component): "contents": '{"first_text"}', "name": "", "props": [], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "RadixThemesText", "props": ['as={"p"}'], - "special_props": set(), + "special_props": [], }, { "args": None, @@ -757,31 +757,31 @@ def test_component_create_unallowed_types(children, test_component): "contents": '{"second_text"}', "name": "", "props": [], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "RadixThemesText", "props": ['as={"p"}'], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "Fragment", "props": [], - "special_props": set(), + "special_props": [], } ], "contents": "", "name": "RadixThemesBox", "props": [], - "special_props": set(), + "special_props": [], }, ], "contents": "", "name": "Fragment", "props": [], - "special_props": set(), + "special_props": [], }, ), ], @@ -1289,12 +1289,12 @@ def handler2(self, arg): id="fstring-class_name", ), pytest.param( - rx.fragment(special_props={TEST_VAR}), + rx.fragment(special_props=[TEST_VAR]), [TEST_VAR], id="direct-special_props", ), pytest.param( - rx.fragment(special_props={LiteralVar.create(f"foo{TEST_VAR}bar")}), + rx.fragment(special_props=[LiteralVar.create(f"foo{TEST_VAR}bar")]), [FORMATTED_TEST_VAR], id="fstring-special_props", ), diff --git a/tests/test_app.py b/tests/test_app.py index 5544736bfe..a2a0cede7a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import functools import io import json @@ -1052,7 +1053,7 @@ def _dynamic_state_event(name, val, **kwargs): f"comp_{arg_name}": exp_val, constants.CompileVars.IS_HYDRATED: False, # "side_effect_counter": exp_index, - "router": exp_router, + "router": dataclasses.asdict(exp_router), } }, events=[ diff --git a/tests/test_state.py b/tests/test_state.py index 29944840eb..3da74fc893 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2,6 +2,7 @@ import asyncio import copy +import dataclasses import datetime import functools import json @@ -58,6 +59,7 @@ "origin": "", "upgrade": "", "connection": "", + "cookie": "", "pragma": "", "cache_control": "", "user_agent": "", @@ -865,8 +867,10 @@ def test_get_headers(test_state, router_data, router_data_headers): router_data: The router data fixture. router_data_headers: The expected headers. """ + print(router_data_headers) test_state.router = RouterData(router_data) - assert test_state.router.headers.dict() == { + print(test_state.router.headers) + assert dataclasses.asdict(test_state.router.headers) == { format.to_snake_case(k): v for k, v in router_data_headers.items() } @@ -1908,19 +1912,21 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): mock_app.event_namespace.emit.assert_called_once() mcall = mock_app.event_namespace.emit.mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) - assert json.loads(mcall.args[1]) == StateUpdate( - delta={ - parent_state.get_full_name(): { - "upper": "", - "sum": 3.14, - }, - grandchild_state.get_full_name(): { - "value2": "42", - }, - GrandchildState3.get_full_name(): { - "computed": "", - }, - } + assert json.loads(mcall.args[1]) == dataclasses.asdict( + StateUpdate( + delta={ + parent_state.get_full_name(): { + "upper": "", + "sum": 3.14, + }, + grandchild_state.get_full_name(): { + "value2": "42", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } + ) ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 8c559f141d..286748943b 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -553,6 +553,7 @@ def test_format_query_params(input, output): "origin": "", "upgrade": "", "connection": "", + "cookie": "", "pragma": "", "cache_control": "", "user_agent": "", diff --git a/tests/utils/test_imports.py b/tests/utils/test_imports.py index e9be5c1be0..c30d1d85c7 100644 --- a/tests/utils/test_imports.py +++ b/tests/utils/test_imports.py @@ -54,17 +54,21 @@ def test_import_var(import_var, expected_name): ( {"react": {"Component"}}, {"react": {"Component"}, "react-dom": {"render"}}, - {"react": {"Component"}, "react-dom": {"render"}}, + {"react": {ImportVar("Component")}, "react-dom": {ImportVar("render")}}, ), ( {"react": {"Component"}, "next/image": {"Image"}}, {"react": {"Component"}, "react-dom": {"render"}}, - {"react": {"Component"}, "react-dom": {"render"}, "next/image": {"Image"}}, + { + "react": {ImportVar("Component")}, + "react-dom": {ImportVar("render")}, + "next/image": {ImportVar("Image")}, + }, ), ( {"react": {"Component"}}, {"": {"some/custom.css"}}, - {"react": {"Component"}, "": {"some/custom.css"}}, + {"react": {ImportVar("Component")}, "": {ImportVar("some/custom.css")}}, ), ], )