diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index 98fa802d30..c4b3e6913b 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -3,10 +3,16 @@ from .base import ArrayVar as ArrayVar from .base import BooleanVar as BooleanVar from .base import ConcatVarOperation as ConcatVarOperation +from .base import FunctionStringVar as FunctionStringVar from .base import FunctionVar as FunctionVar from .base import ImmutableVar as ImmutableVar +from .base import LiteralArrayVar as LiteralArrayVar +from .base import LiteralBooleanVar as LiteralBooleanVar +from .base import LiteralNumberVar as LiteralNumberVar +from .base import LiteralObjectVar as LiteralObjectVar from .base import LiteralStringVar as LiteralStringVar from .base import LiteralVar as LiteralVar from .base import NumberVar as NumberVar from .base import ObjectVar as ObjectVar from .base import StringVar as StringVar +from .base import VarOperationCall as VarOperationCall diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index 258f8d6c32..af0d350f1f 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -7,9 +7,10 @@ import re import sys from functools import cached_property -from typing import Any, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from reflex import constants +from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError @@ -95,6 +96,11 @@ def __hash__(self) -> int: return hash((self._var_name, self._var_type, self._var_data)) def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ return self._var_data def _replace(self, merge_var_data=None, **kwargs: Any): @@ -275,10 +281,250 @@ class ArrayVar(ImmutableVar): class FunctionVar(ImmutableVar): """Base class for immutable function vars.""" + def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return ArgsFunctionOperation( + ("...args",), + VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), + ) + + def call(self, *args: Var | Any) -> VarOperationCall: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return VarOperationCall(self, *args) + + +class FunctionStringVar(FunctionVar): + """Base class for immutable function vars from a string.""" + + def __init__(self, func: str, _var_data: VarData | None = None) -> None: + """Initialize the function var. + + Args: + func: The function to call. + _var_data: Additional hooks and imports associated with the Var. + """ + super(FunctionVar, self).__init__( + _var_name=func, + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class VarOperationCall(ImmutableVar): + """Base class for immutable vars that are the result of a function call.""" + + _func: Optional[FunctionVar] = dataclasses.field(default=None) + _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None + ): + """Initialize the function call var. + + Args: + func: The function to call. + *args: The arguments to call the function with. + _var_data: Additional hooks and imports associated with the Var. + """ + super(VarOperationCall, self).__init__( + _var_name="", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_func", func) + object.__setattr__(self, "_args", args) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._func._get_all_var_data() if self._func is not None else None, + *[var._get_all_var_data() for var in self._args], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArgsFunctionOperation(FunctionVar): + """Base class for immutable function defined via arguments and return expression.""" + + _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + _return_expr: Union[Var, Any] = dataclasses.field(default=None) + + def __init__( + self, + args_names: Tuple[str, ...], + return_expr: Var | Any, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArgsFunctionOperation, self).__init__( + _var_name=f"", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_args_names", args_names) + object.__setattr__(self, "_return_expr", return_expr) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._return_expr._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + class LiteralVar(ImmutableVar): """Base class for immutable literal vars.""" + @classmethod + def create( + cls, + value: Any, + _var_data: VarData | None = None, + ) -> Var: + """Create a var from a value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + + Raises: + TypeError: If the value is not a supported type for LiteralVar. + """ + if isinstance(value, Var): + if _var_data is None: + return value + return value._replace(merge_var_data=_var_data) + + if value is None: + return ImmutableVar.create_safe("null", _var_data=_var_data) + + if isinstance(value, Base): + return LiteralObjectVar( + value.dict(), _var_type=type(value), _var_data=_var_data + ) + + if isinstance(value, str): + return LiteralStringVar.create(value, _var_data=_var_data) + + constructor = type_mapping.get(type(value)) + + if constructor is None: + raise TypeError(f"Unsupported type {type(value)} for LiteralVar.") + + return constructor(value, _var_data=_var_data) + def __post_init__(self): """Post-initialize the var.""" @@ -298,7 +544,25 @@ def __post_init__(self): class LiteralStringVar(LiteralVar): """Base class for immutable literal string vars.""" - _var_value: Optional[str] = dataclasses.field(default=None) + _var_value: str = dataclasses.field(default="") + + def __init__( + self, + _var_value: str, + _var_data: VarData | None = None, + ): + """Initialize the string var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralStringVar, self).__init__( + _var_name=f'"{_var_value}"', + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) @classmethod def create( @@ -316,7 +580,7 @@ def create( The var. """ if REFLEX_VAR_OPENING_TAG in value: - strings_and_vals: list[Var] = [] + strings_and_vals: list[Var | str] = [] offset = 0 # Initialize some methods for reading json. @@ -334,7 +598,7 @@ def json_loads(s): while m := _decode_var_pattern.search(value): start, end = m.span() if start > 0: - strings_and_vals.append(LiteralStringVar.create(value[:start])) + strings_and_vals.append(value[:start]) serialized_data = m.group(1) @@ -364,17 +628,13 @@ def json_loads(s): offset += end - start if value: - strings_and_vals.append(LiteralStringVar.create(value)) + strings_and_vals.append(value) - return ConcatVarOperation.create( - tuple(strings_and_vals), _var_data=_var_data - ) + return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) - return cls( - _var_value=value, - _var_name=f'"{value}"', - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), + return LiteralStringVar( + value, + _var_data=_var_data, ) @@ -386,20 +646,33 @@ def json_loads(s): class ConcatVarOperation(StringVar): """Representing a concatenation of literal string vars.""" - _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple) + _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) - def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None): + def __init__(self, *value: Var | str, _var_data: VarData | None = None): """Initialize the operation of concatenating literal string vars. Args: - _var_value: The list of vars to concatenate. + value: The values to concatenate. _var_data: Additional hooks and imports associated with the Var. """ super(ConcatVarOperation, self).__init__( _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str ) - object.__setattr__(self, "_var_value", _var_value) - object.__setattr__(self, "_var_name", self._cached_var_name) + object.__setattr__(self, "_var_value", value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) @cached_property def _cached_var_name(self) -> str: @@ -408,7 +681,16 @@ def _cached_var_name(self) -> str: Returns: The name of the var. """ - return "+".join([str(element) for element in self._var_value]) + return ( + "(" + + "+".join( + [ + str(element) if isinstance(element, Var) else f'"{element}"' + for element in self._var_value + ] + ) + + ")" + ) @cached_property def _cached_get_all_var_data(self) -> ImmutableVarData | None: @@ -418,7 +700,12 @@ def _cached_get_all_var_data(self) -> ImmutableVarData | None: The VarData of the components and all of its children. """ return ImmutableVarData.merge( - *[var._get_all_var_data() for var in self._var_value], self._var_data + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -433,22 +720,236 @@ def __post_init__(self): """Post-initialize the var.""" pass - @classmethod - def create( - cls, - value: tuple[Var, ...], + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralBooleanVar(LiteralVar): + """Base class for immutable literal boolean vars.""" + + _var_value: bool = dataclasses.field(default=False) + + def __init__( + self, + _var_value: bool, _var_data: VarData | None = None, - ) -> ConcatVarOperation: - """Create a var from a tuple of values. + ): + """Initialize the boolean var. Args: - value: The value to create the var from. + _var_value: The value of the var. _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralBooleanVar, self).__init__( + _var_name="true" if _var_value else "false", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralNumberVar(LiteralVar): + """Base class for immutable literal number vars.""" + + _var_value: float | int = dataclasses.field(default=0) + + def __init__( + self, + _var_value: float | int, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralNumberVar, self).__init__( + _var_name=str(_var_value), + _var_type=type(_var_value), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralObjectVar(LiteralVar): + """Base class for immutable literal object vars.""" + + _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + default_factory=dict + ) + + def __init__( + self, + _var_value: dict[Var | Any, Var | Any], + _var_type: Type = dict, + _var_data: VarData | None = None, + ): + """Initialize the object var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralObjectVar, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "_var_value", + _var_value, + ) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. Returns: - The var. + The attribute of the var. """ - return ConcatVarOperation( - _var_value=value, - _var_data=_var_data, + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "{ " + + ", ".join( + [ + f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" + for key, value in self._var_value.items() + ] + ) + + " }" + ) + + @cached_property + def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + value._get_all_var_data() + for key, value in self._var_value + if isinstance(value, Var) + ], + *[ + key._get_all_var_data() + for key, value in self._var_value + if isinstance(key, Var) + ], + self._var_data, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralArrayVar(LiteralVar): + """Base class for immutable literal array vars.""" + + _var_value: Union[ + List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] + ] = dataclasses.field(default_factory=list) + + def __init__( + self, + _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_data: VarData | None = None, + ): + """Initialize the array var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralArrayVar, self).__init__( + _var_name="", + _var_data=ImmutableVarData.merge(_var_data), + _var_type=list, ) + object.__setattr__(self, "_var_value", _var_value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "[" + + ", ".join( + [str(LiteralVar.create(element)) for element in self._var_value] + ) + + "]" + ) + + @cached_property + def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + +type_mapping = { + int: LiteralNumberVar, + float: LiteralNumberVar, + bool: LiteralBooleanVar, + dict: LiteralObjectVar, + list: LiteralArrayVar, + tuple: LiteralArrayVar, + set: LiteralArrayVar, +} diff --git a/tests/test_var.py b/tests/test_var.py index 78b3a2160d..47d4f223b7 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -8,9 +8,12 @@ from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.experimental.vars.base import ( + ArgsFunctionOperation, ConcatVarOperation, + FunctionStringVar, ImmutableVar, LiteralStringVar, + LiteralVar, ) from reflex.state import BaseState from reflex.utils.imports import ImportVar @@ -858,6 +861,58 @@ def test_state_with_initial_computed_var( assert runtime_dict[var_name] == expected_runtime +def test_literal_var(): + complicated_var = LiteralVar.create( + [ + {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, + [1, 2, 3, 4], + 9, + "string", + True, + False, + None, + set([1, 2, 3]), + ] + ) + assert ( + str(complicated_var) + == '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' + ) + + +def test_function_var(): + addition_func = FunctionStringVar("((a, b) => a + b)") + assert str(addition_func.call(1, 2)) == "(((a, b) => a + b)(1, 2))" + + manual_addition_func = ArgsFunctionOperation( + ("a", "b"), + { + "args": [ImmutableVar.create_safe("a"), ImmutableVar.create_safe("b")], + "result": ImmutableVar.create_safe("a + b"), + }, + ) + assert ( + str(manual_addition_func.call(1, 2)) + == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' + ) + + increment_func = addition_func(1) + assert ( + str(increment_func.call(2)) + == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))" + ) + + create_hello_statement = ArgsFunctionOperation( + ("name",), f"Hello, {ImmutableVar.create_safe('name')}!" + ) + first_name = LiteralStringVar("Steven") + last_name = LiteralStringVar("Universe") + assert ( + str(create_hello_statement.call(f"{first_name} {last_name}")) + == '(((name) => (("Hello, "+name+"!")))(("Steven"+" "+"Universe")))' + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None @@ -931,7 +986,7 @@ def test_fstring_concat(): ), ) - assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"' + assert str(string_concat) == '("foo"+imagination+"bar"+consequences+"baz")' assert isinstance(string_concat, ConcatVarOperation) assert string_concat._get_all_var_data() == ImmutableVarData( state="fear",