diff --git a/reflex/__init__.py b/reflex/__init__.py index 63de1f386c..57e7768ea3 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -331,7 +331,7 @@ "style": ["Style", "toggle_color_mode"], "utils.imports": ["ImportVar"], "utils.serializers": ["serializer"], - "vars": ["Var"], + "vars": ["Var", "field", "Field"], } _SUBMODULES: set[str] = { diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index ef5bcfd8f6..28d768c357 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -189,7 +189,9 @@ from .style import Style as Style from .style import toggle_color_mode as toggle_color_mode from .utils.imports import ImportVar as ImportVar from .utils.serializers import serializer as serializer +from .vars import Field as Field from .vars import Var as Var +from .vars import field as field del compat RADIX_THEMES_MAPPING: dict diff --git a/reflex/state.py b/reflex/state.py index 5798564fa4..330f139b23 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -32,6 +32,7 @@ Type, Union, cast, + get_args, get_type_hints, ) @@ -81,7 +82,7 @@ ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer -from reflex.utils.types import override +from reflex.utils.types import get_origin, override from reflex.vars import VarData if TYPE_CHECKING: @@ -240,12 +241,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): Returns: The Var instance. """ + from reflex.vars import Field + field_name = format.format_state_name(cls.get_full_name()) + "." + f.name return dispatch( field_name=field_name, var_data=VarData.from_state(cls, f.name), - result_var_type=f.outer_type_, + result_var_type=f.outer_type_ + if get_origin(f.outer_type_) is not Field + else get_args(f.outer_type_)[0], ) diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index 56d304cd6b..1a4cebe19a 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -1,8 +1,10 @@ """Immutable-Based Var System.""" +from .base import Field as Field from .base import LiteralVar as LiteralVar from .base import Var as Var from .base import VarData as VarData +from .base import field as field from .base import get_unique_variable_name as get_unique_variable_name from .base import get_uuid_string_var as get_uuid_string_var from .base import var_operation as var_operation diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 7eab62c681..e9429e01ad 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2832,3 +2832,68 @@ def dispatch( _var_data=var_data, _var_type=result_var_type, ).guess_type() + + +V = TypeVar("V") + + +class Field(Generic[T]): + """Shadow class for Var to allow for type hinting in the IDE.""" + + def __set__(self, instance, value: T): + """Set the Var. + + Args: + instance: The instance of the class setting the Var. + value: The value to set the Var to. + """ + + @overload + def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... + + @overload + def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... + + @overload + def __get__(self: Field[str], instance: None, owner) -> StringVar: ... + + @overload + def __get__(self: Field[None], instance: None, owner) -> NoneVar: ... + + @overload + def __get__( + self: Field[List[V]] | Field[Set[V]] | Field[Tuple[V, ...]], + instance: None, + owner, + ) -> ArrayVar[List[V]]: ... + + @overload + def __get__( + self: Field[Dict[str, V]], instance: None, owner + ) -> ObjectVar[Dict[str, V]]: ... + + @overload + def __get__(self, instance: None, owner) -> Var[T]: ... + + @overload + def __get__(self, instance, owner) -> T: ... + + def __get__(self, instance, owner): # type: ignore + """Get the Var. + + Args: + instance: The instance of the class accessing the Var. + owner: The class that the Var is attached to. + """ + + +def field(value: T) -> Field[T]: + """Create a Field with a value. + + Args: + value: The value of the Field. + + Returns: + The Field. + """ + return value # type: ignore