Skip to content

Commit

Permalink
fix conflict with reflex-dev#4431
Browse files Browse the repository at this point in the history
Merge remote-tracking branch 'upstream/main' into state-compression
  • Loading branch information
benedikt-bartscher committed Nov 25, 2024
2 parents 39c7296 + 51ca89b commit 322e7b5
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 12 deletions.
6 changes: 4 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,10 +1462,10 @@ class EventNamespace(AsyncNamespace):
app: App

# Keep a mapping between socket ID and client token.
token_to_sid: dict[str, str] = {}
token_to_sid: dict[str, str]

# Keep a mapping between client token and socket ID.
sid_to_token: dict[str, str] = {}
sid_to_token: dict[str, str]

def __init__(self, namespace: str, app: App):
"""Initialize the event namespace.
Expand All @@ -1475,6 +1475,8 @@ def __init__(self, namespace: str, app: App):
app: The application object.
"""
super().__init__(namespace)
self.token_to_sid = {}
self.sid_to_token = {}
self.app = app

def on_connect(self, sid, environ):
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/plotly/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _exclude_props(self) -> set[str]:

def _render(self):
tag = super()._render()
figure = self.data.to(dict)
figure = self.data.to(dict) if self.data is not None else {}
merge_dicts = [] # Data will be merged and spread from these dict Vars
if self.layout is not None:
# Why is this not a literal dict? Great question... it didn't work
Expand Down
15 changes: 11 additions & 4 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,11 @@ async def _process_event(
if value is None:
continue
hinted_args = value_inside_optional(hinted_args)
if isinstance(value, dict) and inspect.isclass(hinted_args):
if (
isinstance(value, dict)
and inspect.isclass(hinted_args)
and not types.is_generic_alias(hinted_args) # py3.9-py3.10
):
if issubclass(hinted_args, Model):
# Remove non-fields from the payload
payload[arg] = hinted_args(
Expand All @@ -1759,7 +1763,7 @@ async def _process_event(
}
)
elif dataclasses.is_dataclass(hinted_args) or issubclass(
hinted_args, Base
hinted_args, (Base, BaseModelV1, BaseModelV2)
):
payload[arg] = hinted_args(**value)
if isinstance(value, list) and (hinted_args is set or hinted_args is Set):
Expand Down Expand Up @@ -2174,8 +2178,6 @@ def _serialize(self) -> bytes:
payload = b""
try:
payload = pickle.dumps((self._to_schema(), self))
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload))
except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
Expand All @@ -2193,10 +2195,15 @@ def _serialize(self) -> bytes:
except HANDLED_PICKLE_ERRORS as ex:
error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error)

if environment.REFLEX_COMPRESS_STATE.get():
from blosc2 import compress

payload = compress(payload)

if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload))

return payload

@classmethod
Expand Down
136 changes: 131 additions & 5 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
import sys
import threading
from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
from unittest.mock import AsyncMock, Mock

import pytest
Expand Down Expand Up @@ -1828,12 +1838,11 @@ async def _coro_waiter():


@pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture.
Args:
monkeypatch: Pytest monkeypatch object.
state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
Expand All @@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:

setattr(app_module, CompileVars.APP, app)
app.state = TestState
app._state_manager = state_manager
app.event_namespace.emit = AsyncMock() # type: ignore

def _mock_get_app(*args, **kwargs):
Expand All @@ -1854,6 +1862,21 @@ def _mock_get_app(*args, **kwargs):
return app


@pytest.fixture(scope="function")
def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
Args:
mock_app_simple: A simple mock app.
state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
"""
mock_app_simple._state_manager = state_manager
return mock_app_simple


@pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works.
Expand Down Expand Up @@ -3506,3 +3529,106 @@ class SubMixin(Mixin, mixin=True):

with pytest.raises(ReflexRuntimeError):
SubMixin()


class ReflexModel(rx.Model):
"""A model for testing."""

foo: str


class UpcastState(rx.State):
"""A state for testing upcasting."""

passed: bool = False

def rx_model(self, m: ReflexModel): # noqa: D102
assert isinstance(m, ReflexModel)
self.passed = True

def rx_base(self, o: Object): # noqa: D102
assert isinstance(o, Object)
self.passed = True

def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
if o is not None:
assert isinstance(o, Object)
self.passed = True

def rx_basemodelv1(self, m: ModelV1): # noqa: D102
assert isinstance(m, ModelV1)
self.passed = True

def rx_basemodelv2(self, m: ModelV2): # noqa: D102
assert isinstance(m, ModelV2)
self.passed = True

def rx_dataclass(self, dc: ModelDC): # noqa: D102
assert isinstance(dc, ModelDC)
self.passed = True

def py_set(self, s: set): # noqa: D102
assert isinstance(s, set)
self.passed = True

def py_Set(self, s: Set): # noqa: D102
assert isinstance(s, Set)
self.passed = True

def py_tuple(self, t: tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True

def py_Tuple(self, t: Tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True

def py_dict(self, d: dict[str, str]): # noqa: D102
assert isinstance(d, dict)
self.passed = True

def py_list(self, ls: list[str]): # noqa: D102
assert isinstance(ls, list)
self.passed = True

def py_Any(self, a: Any): # noqa: D102
assert isinstance(a, list)
self.passed = True

def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
assert isinstance(u, list)
self.passed = True


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_app_simple")
@pytest.mark.parametrize(
("handler", "payload"),
[
(UpcastState.rx_model, {"m": {"foo": "bar"}}),
(UpcastState.rx_base, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": None}),
(UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
(UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
(UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
(UpcastState.py_set, {"s": ["foo", "foo"]}),
(UpcastState.py_Set, {"s": ["foo", "foo"]}),
(UpcastState.py_tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_dict, {"d": {"foo": "bar"}}),
(UpcastState.py_list, {"ls": ["foo", "foo"]}),
(UpcastState.py_Any, {"a": ["foo"]}),
(UpcastState.py_unresolvable, {"u": ["foo"]}),
],
)
async def test_upcast_event_handler_arg(handler, payload):
"""Test that upcast event handler args work correctly.
Args:
handler: The handler to test.
payload: The payload to test.
"""
state = UpcastState()
async for update in state._process_event(handler, state, payload):
assert update.delta == {UpcastState.get_full_name(): {"passed": True}}

0 comments on commit 322e7b5

Please sign in to comment.