Skip to content

Commit

Permalink
add dataclass support
Browse files Browse the repository at this point in the history
  • Loading branch information
adhami3310 committed Sep 10, 2024
1 parent afa7cfb commit 532de4f
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 49 deletions.
5 changes: 4 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import functools
import inspect
import io
import json
import multiprocessing
import os
import platform
Expand Down Expand Up @@ -1096,6 +1097,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
if delta:
# When the state is modified reset dirty status and emit the delta to the frontend.
state._clean()
print(dir(state.router))
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
sid=state.router.session.session_id,
Expand Down Expand Up @@ -1531,8 +1533,9 @@ async def on_event(self, sid, data):
sid: The Socket.IO session id.
data: The event data.
"""
fields = json.loads(data)
# Get the event.
event = Event.parse_raw(data)
event = Event(**{k: v for k, v in fields.items() if k != "handler"})

self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token
Expand Down
31 changes: 19 additions & 12 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
except ImportError:
from typing_extensions import Annotated


@dataclasses.dataclass(
init=True,
frozen=True,
Expand All @@ -46,10 +47,10 @@ class Event:
name: str

# The routing data where event occurred
router_data: Dict[str, Any] = {}
router_data: Dict[str, Any] = dataclasses.field(default_factory=dict)

# The event payload.
payload: Dict[str, Any] = {}
payload: Dict[str, Any] = dataclasses.field(default_factory=dict)

@property
def substate_token(self) -> str:
Expand Down Expand Up @@ -268,12 +269,6 @@ def __init__(
object.__setattr__(self, "client_handler_name", client_handler_name)
object.__setattr__(self, "args", args or tuple())

class Config:
"""The Pydantic config."""

# Required to allow tuple fields.
frozen = True

def with_args(
self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...]
) -> EventSpec:
Expand Down Expand Up @@ -345,13 +340,13 @@ def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs):
if fn is not None:
default_event_spec = fn()
super().__init__(
fn=fn, # type: ignore
event_actions=default_event_spec.event_actions,
client_handler_name=default_event_spec.client_handler_name,
args=default_event_spec.args,
handler=default_event_spec.handler,
**kwargs,
)
object.__setattr__(self, "fn", fn)
else:
super().__init__(**kwargs)

Expand Down Expand Up @@ -392,22 +387,34 @@ class EventChain(EventActionsMixin):
prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default


class Target(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class Target:
"""A Javascript event target."""

checked: bool = False
value: Any = None


class FrontendEvent(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class FrontendEvent:
"""A Javascript event."""

target: Target = Target()
key: str = ""
value: Any = None


class FileUpload(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class FileUpload:
"""Class to represent a file upload."""

upload_id: Optional[str] = None
Expand Down
21 changes: 20 additions & 1 deletion reflex/ivars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ def to(
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)

if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)

if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError(
Expand Down Expand Up @@ -479,7 +482,11 @@ def guess_type(self) -> ImmutableVar:
):
return self.to(NumberVar, self._var_type)

if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types):
if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
return self.to(ObjectVar, self._var_type)

return self
Expand All @@ -499,6 +506,8 @@ def guess_type(self) -> ImmutableVar:
return self.to(StringVar)
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
if dataclasses.is_dataclass(fixed_type):
return self.to(ObjectVar, self._var_type)
return self

def get_default_value(self) -> Any:
Expand Down Expand Up @@ -985,6 +994,16 @@ def create(
)
return LiteralVar.create(serialized_value, _var_data=_var_data)

if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar.create(
{
k: (None if callable(v) else v)
for k, v in dataclasses.asdict(value).items()
},
_var_type=type(value),
_var_data=_var_data,
)

raise TypeError(
f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
)
Expand Down
110 changes: 77 additions & 33 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import asyncio
import contextlib
import copy
import dataclasses
import functools
import inspect
import json
import os
import uuid
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -78,7 +80,8 @@
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb


class HeaderData(Base):
@dataclasses.dataclass(frozen=True)
class HeaderData:
"""An object containing headers data."""

host: str = ""
Expand All @@ -100,39 +103,59 @@ def __init__(self, router_data: Optional[dict] = None):
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
setattr(self, format.to_snake_case(k), v)
object.__setattr__(self, format.to_snake_case(k), v)
else:
for k in dataclasses.fields(self):
object.__setattr__(self, k.name, "")


class PageData(Base):
@dataclasses.dataclass(frozen=True)
class PageData:
"""An object containing page data."""

host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
path: str = ""
raw_path: str = ""
full_path: str = ""
full_raw_path: str = ""
params: dict = {}
params: dict = dataclasses.field(default_factory=dict)

def __init__(self, router_data: Optional[dict] = None):
"""Initalize the PageData object based on router_data.
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin")
self.path = router_data.get(constants.RouteVar.PATH, "")
self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "")
self.full_path = f"{self.host}{self.path}"
self.full_raw_path = f"{self.host}{self.raw_path}"
self.params = router_data.get(constants.RouteVar.QUERY, {})
object.__setattr__(
self,
"host",
router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""),
)
object.__setattr__(
self, "path", router_data.get(constants.RouteVar.PATH, "")
)
object.__setattr__(
self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "")
)
object.__setattr__(self, "full_path", f"{self.host}{self.path}")
object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}")
object.__setattr__(
self, "params", router_data.get(constants.RouteVar.QUERY, {})
)
else:
object.__setattr__(self, "host", "")
object.__setattr__(self, "path", "")
object.__setattr__(self, "raw_path", "")
object.__setattr__(self, "full_path", "")
object.__setattr__(self, "full_raw_path", "")
object.__setattr__(self, "params", {})


class SessionData(Base):
@dataclasses.dataclass(frozen=True, init=False)
class SessionData:
"""An object containing session data."""

client_token: str = ""
Expand All @@ -145,30 +168,42 @@ def __init__(self, router_data: Optional[dict] = None):
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
else:
client_token = client_ip = session_id = ""
object.__setattr__(self, "client_token", client_token)
object.__setattr__(self, "client_ip", client_ip)
object.__setattr__(self, "session_id", session_id)


class RouterData(Base):
@dataclasses.dataclass(frozen=True, init=False)
class RouterData:
"""An object containing RouterData."""

session: SessionData = SessionData()
headers: HeaderData = HeaderData()
page: PageData = PageData()
session: SessionData = dataclasses.field(default_factory=SessionData)
headers: HeaderData = dataclasses.field(default_factory=HeaderData)
page: PageData = dataclasses.field(default_factory=PageData)

def __init__(self, router_data: Optional[dict] = None):
"""Initialize the RouterData object.
Args:
router_data: the router_data dict.
"""
super().__init__()
self.session = SessionData(router_data)
self.headers = HeaderData(router_data)
self.page = PageData(router_data)
object.__setattr__(self, "session", SessionData(router_data))
object.__setattr__(self, "headers", HeaderData(router_data))
object.__setattr__(self, "page", PageData(router_data))

def toJson(self) -> str:
"""Convert the object to a JSON string.
Returns:
The JSON string.
"""
return json.dumps(dataclasses.asdict(self))


def _no_chain_background_task(
Expand Down Expand Up @@ -244,10 +279,11 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
return token, state_name


@dataclasses.dataclass(frozen=True, init=False)
class EventHandlerSetVar(EventHandler):
"""A special event handler to wrap setvar functionality."""

state_cls: Type[BaseState]
state_cls: Type[BaseState] = dataclasses.field(init=False)

def __init__(self, state_cls: Type[BaseState]):
"""Initialize the EventHandlerSetVar.
Expand All @@ -258,8 +294,8 @@ def __init__(self, state_cls: Type[BaseState]):
super().__init__(
fn=type(self).setvar,
state_full_name=state_cls.get_full_name(),
state_cls=state_cls, # type: ignore
)
object.__setattr__(self, "state_cls", state_cls)

def setvar(self, var_name: str, value: Any):
"""Set the state variable to the value of the event.
Expand Down Expand Up @@ -1859,9 +1895,6 @@ def __getstate__(self):
return state


EventHandlerSetVar.update_forward_refs()


class State(BaseState):
"""The app Base State."""

Expand Down Expand Up @@ -2293,18 +2326,29 @@ def _as_state_update(self, *args, **kwargs) -> StateUpdate:
self._self_mutable = original_mutable


class StateUpdate(Base):
@dataclasses.dataclass(
frozen=True,
)
class StateUpdate:
"""A state update sent to the frontend."""

# The state delta.
delta: Delta = {}
delta: Delta = dataclasses.field(default_factory=dict)

# Events to be added to the event queue.
events: List[Event] = []
events: List[Event] = dataclasses.field(default_factory=list)

# Whether this is the final state update for the event.
final: bool = True

def json(self) -> str:
"""Convert the state update to a JSON string.
Returns:
The state update as a JSON string.
"""
return json.dumps(dataclasses.asdict(self))


class StateManager(Base, ABC):
"""A class to manage many client states."""
Expand Down
5 changes: 5 additions & 0 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import dataclasses
import inspect
import json
import os
Expand Down Expand Up @@ -623,6 +624,10 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}

# Hand dataclasses.
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}

# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]
Expand Down
Loading

0 comments on commit 532de4f

Please sign in to comment.