diff --git a/reflex/app.py b/reflex/app.py index afc40e3b88..fc8efb4201 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1462,10 +1462,10 @@ class EventNamespace(AsyncNamespace): app: App # Keep a mapping between socket ID and client token. - token_to_sid: dict[str, str] = {} + token_to_sid: dict[str, str] # Keep a mapping between client token and socket ID. - sid_to_token: dict[str, str] = {} + sid_to_token: dict[str, str] def __init__(self, namespace: str, app: App): """Initialize the event namespace. @@ -1475,6 +1475,8 @@ def __init__(self, namespace: str, app: App): app: The application object. """ super().__init__(namespace) + self.token_to_sid = {} + self.sid_to_token = {} self.app = app def on_connect(self, sid, environ): diff --git a/reflex/components/plotly/plotly.py b/reflex/components/plotly/plotly.py index 1e551ce87f..f07a743cb5 100644 --- a/reflex/components/plotly/plotly.py +++ b/reflex/components/plotly/plotly.py @@ -255,7 +255,7 @@ def _exclude_props(self) -> set[str]: def _render(self): tag = super()._render() - figure = self.data.to(dict) + figure = self.data.to(dict) if self.data is not None else {} merge_dicts = [] # Data will be merged and spread from these dict Vars if self.layout is not None: # Why is this not a literal dict? Great question... it didn't work diff --git a/reflex/state.py b/reflex/state.py index 58a74e1cc0..61c12fed29 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1748,7 +1748,11 @@ async def _process_event( if value is None: continue hinted_args = value_inside_optional(hinted_args) - if isinstance(value, dict) and inspect.isclass(hinted_args): + if ( + isinstance(value, dict) + and inspect.isclass(hinted_args) + and not types.is_generic_alias(hinted_args) # py3.9-py3.10 + ): if issubclass(hinted_args, Model): # Remove non-fields from the payload payload[arg] = hinted_args( @@ -1759,7 +1763,7 @@ async def _process_event( } ) elif dataclasses.is_dataclass(hinted_args) or issubclass( - hinted_args, Base + hinted_args, (Base, BaseModelV1, BaseModelV2) ): payload[arg] = hinted_args(**value) if isinstance(value, list) and (hinted_args is set or hinted_args is Set): @@ -2174,8 +2178,6 @@ def _serialize(self) -> bytes: payload = b"" try: payload = pickle.dumps((self._to_schema(), self)) - if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: - self._check_state_size(len(payload)) except HANDLED_PICKLE_ERRORS as og_pickle_error: error = ( f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " @@ -2193,10 +2195,15 @@ def _serialize(self) -> bytes: except HANDLED_PICKLE_ERRORS as ex: error += f"Dill was also unable to pickle the state: {ex}" console.warn(error) + if environment.REFLEX_COMPRESS_STATE.get(): from blosc2 import compress payload = compress(payload) + + if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: + self._check_state_size(len(payload)) + return payload @classmethod diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c8a52e6c0b..45c021bd82 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,17 @@ import sys import threading from textwrap import dedent -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) from unittest.mock import AsyncMock, Mock import pytest @@ -1828,12 +1838,11 @@ async def _coro_waiter(): @pytest.fixture(scope="function") -def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: - """Mock app fixture. +def mock_app_simple(monkeypatch) -> rx.App: + """Simple Mock app fixture. Args: monkeypatch: Pytest monkeypatch object. - state_manager: A state manager. Returns: The app, after mocking out prerequisites.get_app() @@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app._state_manager = state_manager app.event_namespace.emit = AsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): @@ -1854,6 +1862,21 @@ def _mock_get_app(*args, **kwargs): return app +@pytest.fixture(scope="function") +def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: + """Mock app fixture. + + Args: + mock_app_simple: A simple mock app. + state_manager: A state manager. + + Returns: + The app, after mocking out prerequisites.get_app() + """ + mock_app_simple._state_manager = state_manager + return mock_app_simple + + @pytest.mark.asyncio async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): """Test that the state proxy works. @@ -3506,3 +3529,106 @@ class SubMixin(Mixin, mixin=True): with pytest.raises(ReflexRuntimeError): SubMixin() + + +class ReflexModel(rx.Model): + """A model for testing.""" + + foo: str + + +class UpcastState(rx.State): + """A state for testing upcasting.""" + + passed: bool = False + + def rx_model(self, m: ReflexModel): # noqa: D102 + assert isinstance(m, ReflexModel) + self.passed = True + + def rx_base(self, o: Object): # noqa: D102 + assert isinstance(o, Object) + self.passed = True + + def rx_base_or_none(self, o: Optional[Object]): # noqa: D102 + if o is not None: + assert isinstance(o, Object) + self.passed = True + + def rx_basemodelv1(self, m: ModelV1): # noqa: D102 + assert isinstance(m, ModelV1) + self.passed = True + + def rx_basemodelv2(self, m: ModelV2): # noqa: D102 + assert isinstance(m, ModelV2) + self.passed = True + + def rx_dataclass(self, dc: ModelDC): # noqa: D102 + assert isinstance(dc, ModelDC) + self.passed = True + + def py_set(self, s: set): # noqa: D102 + assert isinstance(s, set) + self.passed = True + + def py_Set(self, s: Set): # noqa: D102 + assert isinstance(s, Set) + self.passed = True + + def py_tuple(self, t: tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_Tuple(self, t: Tuple): # noqa: D102 + assert isinstance(t, tuple) + self.passed = True + + def py_dict(self, d: dict[str, str]): # noqa: D102 + assert isinstance(d, dict) + self.passed = True + + def py_list(self, ls: list[str]): # noqa: D102 + assert isinstance(ls, list) + self.passed = True + + def py_Any(self, a: Any): # noqa: D102 + assert isinstance(a, list) + self.passed = True + + def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore + assert isinstance(u, list) + self.passed = True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_app_simple") +@pytest.mark.parametrize( + ("handler", "payload"), + [ + (UpcastState.rx_model, {"m": {"foo": "bar"}}), + (UpcastState.rx_base, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}), + (UpcastState.rx_base_or_none, {"o": None}), + (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}), + (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}), + (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}), + (UpcastState.py_set, {"s": ["foo", "foo"]}), + (UpcastState.py_Set, {"s": ["foo", "foo"]}), + (UpcastState.py_tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_Tuple, {"t": ["foo", "foo"]}), + (UpcastState.py_dict, {"d": {"foo": "bar"}}), + (UpcastState.py_list, {"ls": ["foo", "foo"]}), + (UpcastState.py_Any, {"a": ["foo"]}), + (UpcastState.py_unresolvable, {"u": ["foo"]}), + ], +) +async def test_upcast_event_handler_arg(handler, payload): + """Test that upcast event handler args work correctly. + + Args: + handler: The handler to test. + payload: The payload to test. + """ + state = UpcastState() + async for update in state._process_event(handler, state, payload): + assert update.delta == {UpcastState.get_full_name(): {"passed": True}}