From 73516b924740a1c0298df7fb96f71d8ba2f64b11 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 3 Oct 2024 19:19:06 -0700 Subject: [PATCH] [ENG-3867] Garden Variety Pickle (#4054) * Use regular `pickle` module from stdlib * Avoid recreating the rx.State tree for every `get_state` * Remove dill dependency * relock deps --- poetry.lock | 96 +++++++++++++++++--------------------- pyproject.toml | 1 - reflex/state.py | 87 ++++++++++++++++++++++------------ reflex/utils/exceptions.py | 4 ++ 4 files changed, 103 insertions(+), 85 deletions(-) diff --git a/poetry.lock b/poetry.lock index f94a3832a89..928731c2653 100644 --- a/poetry.lock +++ b/poetry.lock @@ -516,21 +516,6 @@ files = [ {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"}, ] -[[package]] -name = "dill" -version = "0.3.8" -description = "serialize all of Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, - {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] - [[package]] name = "distlib" version = "0.3.8" @@ -719,13 +704,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.5" +version = "1.0.6" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, ] [package.dependencies] @@ -736,7 +721,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] +trio = ["trio (>=0.22.0,<1.0)"] [[package]] name = "httpx" @@ -863,21 +848,25 @@ test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-c [[package]] name = "jaraco-functools" -version = "4.0.2" +version = "4.1.0" description = "Functools like those found in stdlib" optional = false python-versions = ">=3.8" files = [ - {file = "jaraco.functools-4.0.2-py3-none-any.whl", hash = "sha256:c9d16a3ed4ccb5a889ad8e0b7a343401ee5b2a71cee6ed192d3f68bc351e94e3"}, - {file = "jaraco_functools-4.0.2.tar.gz", hash = "sha256:3460c74cd0d32bf82b9576bbb3527c4364d5b27a21f5158a62aed6c4b42e23f5"}, + {file = "jaraco.functools-4.1.0-py3-none-any.whl", hash = "sha256:ad159f13428bc4acbf5541ad6dec511f91573b90fba04df61dafa2a1231cf649"}, + {file = "jaraco_functools-4.1.0.tar.gz", hash = "sha256:70f7e0e2ae076498e212562325e805204fc092d7b4c17e0e86c959e249701a9d"}, ] [package.dependencies] more-itertools = "*" [package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["jaraco.classes", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.classes", "pytest (>=6,!=8.1.*)"] +type = ["pytest-mypy"] [[package]] name = "jeepney" @@ -1788,13 +1777,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyproject-hooks" -version = "1.1.0" +version = "1.2.0" description = "Wrappers to call pyproject.toml-based build backend hooks." optional = false python-versions = ">=3.7" files = [ - {file = "pyproject_hooks-1.1.0-py3-none-any.whl", hash = "sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2"}, - {file = "pyproject_hooks-1.1.0.tar.gz", hash = "sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965"}, + {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"}, + {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, ] [[package]] @@ -1992,13 +1981,13 @@ docs = ["sphinx"] [[package]] name = "python-multipart" -version = "0.0.10" +version = "0.0.12" description = "A streaming multipart parser for Python" optional = false python-versions = ">=3.8" files = [ - {file = "python_multipart-0.0.10-py3-none-any.whl", hash = "sha256:2b06ad9e8d50c7a8db80e3b56dab590137b323410605af2be20d62a5f1ba1dc8"}, - {file = "python_multipart-0.0.10.tar.gz", hash = "sha256:46eb3c6ce6fdda5fb1a03c7e11d490e407c6930a2703fe7aef4da71c374688fa"}, + {file = "python_multipart-0.0.12-py3-none-any.whl", hash = "sha256:43dcf96cf65888a9cd3423544dd0d75ac10f7aa0c3c28a175bbcd00c9ce1aebf"}, + {file = "python_multipart-0.0.12.tar.gz", hash = "sha256:045e1f98d719c1ce085ed7f7e1ef9d8ccc8c02ba02b5566d5f7521410ced58cb"}, ] [[package]] @@ -2143,31 +2132,31 @@ md = ["cmarkgfm (>=0.8.0)"] [[package]] name = "redis" -version = "5.0.8" +version = "5.1.0" description = "Python client for Redis database and key-value store" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, - {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, + {file = "redis-5.1.0-py3-none-any.whl", hash = "sha256:fd4fccba0d7f6aa48c58a78d76ddb4afc698f5da4a2c1d03d916e4fd7ab88cdd"}, + {file = "redis-5.1.0.tar.gz", hash = "sha256:b756df1e4a3858fcc0ef861f3fc53623a96c41e2b1f5304e09e0fe758d333d40"}, ] [package.dependencies] async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} [package.extras] -hiredis = ["hiredis (>1.0.0)"] -ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] [[package]] name = "reflex-chakra" -version = "0.6.0" +version = "0.6.1" description = "reflex using chakra components" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "reflex_chakra-0.6.0-py3-none-any.whl", hash = "sha256:eca1593fca67289e05591dd21fbcc8632c119d64a08bdc41fd995055a114cc91"}, - {file = "reflex_chakra-0.6.0.tar.gz", hash = "sha256:db1c7b48f1ba547bf91e5af103fce6fc7191d7225b414ebfbada7d983e33dd87"}, + {file = "reflex_chakra-0.6.1-py3-none-any.whl", hash = "sha256:824d461264b6d2c836ba4a2a430e677a890b82e83da149672accfc58786442fa"}, + {file = "reflex_chakra-0.6.1.tar.gz", hash = "sha256:4b9b3c8bada19cbb4d1b8d8bc4ab0460ec008a91f380010c34d416d5b613dc07"}, ] [package.dependencies] @@ -2247,18 +2236,19 @@ idna2008 = ["idna"] [[package]] name = "rich" -version = "13.8.1" +version = "13.9.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, - {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, + {file = "rich-13.9.1-py3-none-any.whl", hash = "sha256:b340e739f30aa58921dc477b8adaa9ecdb7cecc217be01d93730ee1bc8aa83be"}, + {file = "rich-13.9.1.tar.gz", hash = "sha256:097cffdf85db1babe30cc7deba5ab3a29e1b9885047dab24c57e9a7f8a9c1466"}, ] [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -2595,13 +2585,13 @@ files = [ [[package]] name = "tomli" -version = "2.0.1" +version = "2.0.2" description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] [[package]] @@ -2734,13 +2724,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.30.6" +version = "0.31.0" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.30.6-py3-none-any.whl", hash = "sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5"}, - {file = "uvicorn-0.30.6.tar.gz", hash = "sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788"}, + {file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"}, + {file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"}, ] [package.dependencies] @@ -2753,13 +2743,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [[package]] name = "virtualenv" -version = "20.26.5" +version = "20.26.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.5-py3-none-any.whl", hash = "sha256:4f3ac17b81fba3ce3bd6f4ead2749a72da5929c01774948e243db9ba41df4ff6"}, - {file = "virtualenv-20.26.5.tar.gz", hash = "sha256:ce489cac131aa58f4b25e321d6d186171f78e6cb13fafbf32a840cee67733ff4"}, + {file = "virtualenv-20.26.6-py3-none-any.whl", hash = "sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2"}, + {file = "virtualenv-20.26.6.tar.gz", hash = "sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48"}, ] [package.dependencies] @@ -3011,4 +3001,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "adccd071775567aeefe219261aeb9e222906c865745f03edb1e770edc79c44ac" +content-hash = "e4b462ebfae90550ba7fa49b360d7110c0d344ee616c23989c22d866ef8f6f31" diff --git a/pyproject.toml b/pyproject.toml index 08c4fbdbcc5..2817413687e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" -dill = ">=0.3.8,<0.4" fastapi = ">=0.96.0,!=0.111.0,!=0.111.1" gunicorn = ">=20.1.0,<24.0" jinja2 = ">=3.1.2,<4.0" diff --git a/reflex/state.py b/reflex/state.py index b1988e38a48..5798564fa4c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -9,6 +9,7 @@ import functools import inspect import os +import pickle import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -19,6 +20,7 @@ TYPE_CHECKING, Any, AsyncIterator, + BinaryIO, Callable, ClassVar, Dict, @@ -33,7 +35,6 @@ get_type_hints, ) -import dill from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -76,6 +77,7 @@ ImmutableStateError, LockExpiredError, SetUndefinedStateVarError, + StateSchemaMismatchError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -1914,7 +1916,7 @@ async def __aexit__(self, *exc_info: Any) -> None: def __getstate__(self): """Get the state for redis serialization. - This method is called by cloudpickle to serialize the object. + This method is called by pickle to serialize the object. It explicitly removes parent_state and substates because those are serialized separately by the StateManagerRedis to allow for better horizontal scaling as state size increases. @@ -1930,6 +1932,43 @@ def __getstate__(self): state["__dict__"].pop("_was_touched", None) return state + def _serialize(self) -> bytes: + """Serialize the state for redis. + + Returns: + The serialized state. + """ + return pickle.dumps((state_to_schema(self), self)) + + @classmethod + def _deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> BaseState: + """Deserialize the state from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized state. + + Raises: + ValueError: If both data and fp are provided, or neither are provided. + StateSchemaMismatchError: If the state schema does not match the expected schema. + """ + if data is not None and fp is None: + (substate_schema, state) = pickle.loads(data) + elif fp is not None and data is None: + (substate_schema, state) = pickle.load(fp) + else: + raise ValueError("Only one of `data` or `fp` must be provided") + if substate_schema != state_to_schema(state): + raise StateSchemaMismatchError() + return state + class State(BaseState): """The app Base State.""" @@ -2086,7 +2125,11 @@ def create(cls, *children, **props) -> "Component": """ cls._per_component_state_instance_count += 1 state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" - component_state = type(state_cls_name, (cls, State), {}, mixin=False) + component_state = type( + state_cls_name, (cls, State), {"__module__": __name__}, mixin=False + ) + # Save a reference to the dynamic state for pickle/unpickle. + globals()[state_cls_name] = component_state component = component_state.get_component(*children, **props) component.State = component_state return component @@ -2552,7 +2595,7 @@ def is_serializable(value: Any) -> bool: Whether the value is serializable. """ try: - return bool(dill.dumps(value)) + return bool(pickle.dumps(value)) except Exception: return False @@ -2688,8 +2731,7 @@ async def load_state(self, token: str, root_state: BaseState) -> BaseState: if token_path.exists(): try: with token_path.open(mode="rb") as file: - (substate_schema, substate) = dill.load(file) - if substate_schema == state_to_schema(substate): + substate = BaseState._deserialize(fp=file) await self.populate_substates(client_token, substate, root_state) return substate except Exception: @@ -2731,10 +2773,12 @@ async def get_state( client_token, substate_address = _split_substate_key(token) root_state_token = _substate_key(client_token, substate_address.split(".")[0]) + root_state = self.states.get(root_state_token) + if root_state is None: + # Create a new root state which will be persisted in the next set_state call. + root_state = self.state(_reflex_internal_init=True) - return await self.load_state( - root_state_token, self.state(_reflex_internal_init=True) - ) + return await self.load_state(root_state_token, root_state) async def set_state_for_substate(self, client_token: str, substate: BaseState): """Set the state for a substate. @@ -2747,7 +2791,7 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState): self.states[substate_token] = substate - state_dilled = dill.dumps((state_to_schema(substate), substate)) + state_dilled = substate._serialize() if not self.states_directory.exists(): self.states_directory.mkdir(parents=True, exist_ok=True) self.token_path(substate_token).write_bytes(state_dilled) @@ -2790,25 +2834,6 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: await self.set_state(token, state) -# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes -if not isinstance(State.validate.__func__, FunctionType): - cython_function_or_method = type(State.validate.__func__) - - @dill.register(cython_function_or_method) - def _dill_reduce_cython_function_or_method(pickler, obj): - # Ignore cython function when pickling. - pass - - -@dill.register(type(State)) -def _dill_reduce_state(pickler, obj): - if obj is not State and issubclass(obj, State): - # Avoid serializing subclasses of State, instead get them by reference from the State class. - pickler.save_reduce(State.get_class_substate, (obj.get_full_name(),), obj=obj) - else: - dill.Pickler.dispatch[type](pickler, obj) - - def _default_lock_expiration() -> int: """Get the default lock expiration time. @@ -2948,7 +2973,7 @@ async def get_state( if redis_state is not None: # Deserialize the substate. - state = dill.loads(redis_state) + state = BaseState._deserialize(data=redis_state) # Populate parent state if missing and requested. if parent_state is None: @@ -3060,7 +3085,7 @@ async def set_state( ) # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). if state._get_was_touched(): - pickle_state = dill.dumps(state, byref=True) + pickle_state = state._serialize() self._warn_if_too_large(state, len(pickle_state)) await self.redis.set( _substate_key(client_token, state), diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 0383f7ba620..8bce605b566 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError): class SetUndefinedStateVarError(ReflexError, AttributeError): """Raised when setting the value of a var without first declaring it.""" + + +class StateSchemaMismatchError(ReflexError, TypeError): + """Raised when the serialized schema of a state class does not match the current schema."""