diff --git a/reflex/state.py b/reflex/state.py index e7e6bcf326b..65bc71d904d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -107,6 +107,7 @@ StateSchemaMismatchError, StateSerializationError, StateTooLargeError, + UnretrievableVarValueError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -143,6 +144,9 @@ ValueError, ) +# For BaseState.get_var_value +VAR_TYPE = TypeVar("VAR_TYPE") + def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -1596,6 +1600,42 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState: # Slow case - fetch missing parent states from redis. return await self._get_state_from_redis(state_cls) + async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: + """Get the value of an rx.Var from another state. + + Args: + var: The var to get the value for. + + Returns: + The value of the var. + + Raises: + UnretrievableVarValueError: If the var does not have a literal value + or associated state. + """ + # Oopsie case: you didn't give me a Var... so get what you give. + if not isinstance(var, Var): + return var # type: ignore + + # Fast case: this is a literal var and the value is known. + if hasattr(var, "_var_value"): + return var._var_value + + var_data = var._get_all_var_data() + if var_data is None or not var_data.state: + raise UnretrievableVarValueError( + f"Unable to retrieve value for {var._js_expr}: not associated with any state." + ) + # Fastish case: this var belongs to this state + if var_data.state == self.get_full_name(): + return getattr(self, var_data.field_name) + + # Slow case: this var belongs to another state + other_state = await self.get_state( + self._get_root_state().get_class_substate(var_data.state) + ) + return getattr(other_state, var_data.field_name) + def _get_event_handler( self, event: Event ) -> tuple[BaseState | StateProxy, EventHandler]: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index ae5ec016836..bceadc977ee 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -187,3 +187,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn: class InvalidLockWarningThresholdError(ReflexError): """Raised when an invalid lock warning threshold is provided.""" + + +class UnretrievableVarValueError(ReflexError): + """Raised when the value of a var is not retrievable.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 912d72f4f15..c1780b4f047 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -60,6 +60,7 @@ ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, + UnretrievableVarValueError, ) from reflex.utils.format import json_dumps from reflex.vars.base import Var, computed_var @@ -115,7 +116,7 @@ class TestState(BaseState): # Set this class as not test one __test__ = False - num1: int + num1: rx.Field[int] num2: float = 3.14 key: str map_key: str = "a" @@ -163,7 +164,7 @@ class ChildState(TestState): """A child state fixture.""" value: str - count: int = 23 + count: rx.Field[int] = rx.field(23) def change_both(self, value: str, count: int): """Change both the value and count. @@ -1663,7 +1664,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: @pytest.fixture() -def substate_token(state_manager, token): +def substate_token(state_manager, token) -> str: """A token + substate name for looking up in state manager. Args: @@ -3764,3 +3765,32 @@ async def test_upcast_event_handler_arg(handler, payload): state = UpcastState() async for update in state._process_event(handler, state, payload): assert update.delta == {UpcastState.get_full_name(): {"passed": True}} + + +@pytest.mark.asyncio +async def test_get_var_value(state_manager: StateManager, substate_token: str): + """Test that get_var_value works correctly. + + Args: + state_manager: The state manager to use. + substate_token: Token for the substate used by state_manager. + """ + state = await state_manager.get_state(substate_token) + + # State Var from same state + assert await state.get_var_value(TestState.num1) == 0 + state.num1 = 42 + assert await state.get_var_value(TestState.num1) == 42 + + # State Var from another state + child_state = await state.get_state(ChildState) + assert await state.get_var_value(ChildState.count) == 23 + child_state.count = 66 + assert await state.get_var_value(ChildState.count) == 66 + + # LiteralVar with known value + assert await state.get_var_value(rx.Var.create([1, 2, 3])) == [1, 2, 3] + + # Generic Var with no state + with pytest.raises(UnretrievableVarValueError): + await state.get_var_value(rx.Var("undefined"))