diff --git a/reflex/event.py b/reflex/event.py index f93bc63d2a..7a4e0713b1 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -12,7 +12,6 @@ from typing import ( Any, Callable, - ClassVar, Dict, Generic, List, @@ -33,9 +32,7 @@ from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData from reflex.vars.base import ( - LiteralNoneVar, LiteralVar, - ToOperation, Var, ) from reflex.vars.function import ( @@ -1254,7 +1251,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: return signature.replace(parameters=(new_param, *signature.parameters.values())) -class EventVar(ObjectVar): +class EventVar(ObjectVar, python_types=EventSpec): """Base class for event vars.""" @@ -1315,7 +1312,7 @@ def create( ) -class EventChainVar(FunctionVar): +class EventChainVar(FunctionVar, python_types=EventChain): """Base class for event chain vars.""" @@ -1384,32 +1381,6 @@ def create( ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToEventVarOperation(ToOperation, EventVar): - """Result of a cast to an event var.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = EventSpec - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToEventChainVarOperation(ToOperation, EventChainVar): - """Result of a cast to an event chain var.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = EventChain - - G = ParamSpec("G") IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var] @@ -1537,8 +1508,6 @@ class EventNamespace(types.SimpleNamespace): LiteralEventVar = LiteralEventVar EventChainVar = EventChainVar LiteralEventChainVar = LiteralEventChainVar - ToEventVarOperation = ToEventVarOperation - ToEventChainVarOperation = ToEventChainVarOperation EventType = EventType __call__ = staticmethod(event_handler) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 14e7251bb7..aee7db678b 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -19,6 +19,7 @@ TYPE_CHECKING, Any, Callable, + ClassVar, Dict, FrozenSet, Generic, @@ -37,7 +38,13 @@ overload, ) -from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override +from typing_extensions import ( + ParamSpec, + TypeGuard, + deprecated, + get_type_hints, + override, +) from reflex import constants from reflex.base import Base @@ -61,15 +68,13 @@ if TYPE_CHECKING: from reflex.state import BaseState - from .function import FunctionVar, ToFunctionOperation + from .function import FunctionVar from .number import ( BooleanVar, NumberVar, - ToBooleanVarOperation, - ToNumberVarOperation, ) - from .object import ObjectVar, ToObjectOperation - from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + from .object import ObjectVar + from .sequence import ArrayVar, StringVar VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) @@ -78,6 +83,184 @@ warnings.filterwarnings("ignore", message="fields may not start with an underscore") +@dataclasses.dataclass( + eq=False, + frozen=True, +) +class VarSubclassEntry: + """Entry for a Var subclass.""" + + var_subclass: Type[Var] + to_var_subclass: Type[ToOperation] + python_types: Tuple[GenericType, ...] + + +_var_subclasses: List[VarSubclassEntry] = [] +_var_literal_subclasses: List[Tuple[Type[LiteralVar], VarSubclassEntry]] = [] + + +@dataclasses.dataclass( + eq=True, + frozen=True, +) +class VarData: + """Metadata associated with a x.""" + + # The name of the enclosing state. + state: str = dataclasses.field(default="") + + # The name of the field in the state. + field_name: str = dataclasses.field(default="") + + # Imports needed to render this var + imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple) + + # Hooks that need to be present in the component to render this var + hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, + state: str = "", + field_name: str = "", + imports: ImportDict | ParsedImportDict | None = None, + hooks: dict[str, None] | None = None, + ): + """Initialize the var data. + + Args: + state: The name of the enclosing state. + field_name: The name of the field in the state. + imports: Imports needed to render this var. + hooks: Hooks that need to be present in the component to render this var. + """ + immutable_imports: ImmutableParsedImportDict = tuple( + sorted( + ((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items()) + ) + ) + object.__setattr__(self, "state", state) + object.__setattr__(self, "field_name", field_name) + object.__setattr__(self, "imports", immutable_imports) + object.__setattr__(self, "hooks", tuple(hooks or {})) + + def old_school_imports(self) -> ImportDict: + """Return the imports as a mutable dict. + + Returns: + The imports as a mutable dict. + """ + return dict((k, list(v)) for k, v in self.imports) + + @classmethod + def merge(cls, *others: VarData | None) -> VarData | None: + """Merge multiple var data objects. + + Args: + *others: The var data objects to merge. + + Returns: + The merged var data object. + """ + state = "" + field_name = "" + _imports = {} + hooks = {} + for var_data in others: + if var_data is None: + continue + state = state or var_data.state + field_name = field_name or var_data.field_name + _imports = imports.merge_imports(_imports, var_data.imports) + hooks.update( + var_data.hooks + if isinstance(var_data.hooks, dict) + else {k: None for k in var_data.hooks} + ) + + if state or _imports or hooks or field_name: + return VarData( + state=state, + field_name=field_name, + imports=_imports, + hooks=hooks, + ) + return None + + def __bool__(self) -> bool: + """Check if the var data is non-empty. + + Returns: + True if any field is set to a non-default value. + """ + return bool(self.state or self.imports or self.hooks or self.field_name) + + @classmethod + def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: + """Set the state of the var. + + Args: + state: The state to set or the full name of the state. + field_name: The name of the field in the state. Optional. + + Returns: + The var with the set state. + """ + from reflex.utils import format + + state_name = state if isinstance(state, str) else state.get_full_name() + return VarData( + state=state_name, + field_name=field_name, + hooks={ + "const {0} = useContext(StateContexts.{0})".format( + format.format_state_name(state_name) + ): None + }, + imports={ + f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], + "react": [ImportVar(tag="useContext")], + }, + ) + + +def _decode_var_immutable(value: str) -> tuple[VarData | None, str]: + """Decode the state name from a formatted var. + + Args: + value: The value to extract the state name from. + + Returns: + The extracted state name and the value without the state name. + """ + var_datas = [] + if isinstance(value, str): + # fast path if there is no encoded VarData + if constants.REFLEX_VAR_OPENING_TAG not in value: + return None, value + + offset = 0 + + # Find all tags. + while m := _decode_var_pattern.search(value): + start, end = m.span() + value = value[:start] + value[end:] + + serialized_data = m.group(1) + + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + var_data = var._get_all_var_data() + + if var_data is not None: + var_datas.append(var_data) + offset += end - start + + return VarData.merge(*var_datas) if var_datas else None, value + + @dataclasses.dataclass( eq=False, frozen=True, @@ -151,6 +334,40 @@ def _var_is_string(self) -> bool: """ return False + def __init_subclass__( + cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs + ): + """Initialize the subclass. + + Args: + python_types: The python types that the var represents. + **kwargs: Additional keyword arguments. + """ + super().__init_subclass__(**kwargs) + + if python_types is not types.Unset: + python_types = ( + python_types if isinstance(python_types, tuple) else (python_types,) + ) + + @dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, + ) + class ToVarOperation(ToOperation, cls): + """Base class of converting a var to another var type.""" + + _original: Var = dataclasses.field( + default=Var(_js_expr="null", _var_type=None), + ) + + _default_var_type: ClassVar[GenericType] = python_types[0] + + ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation' + + _var_subclasses.append(VarSubclassEntry(cls, ToVarOperation, python_types)) + def __post_init__(self): """Post-initialize the var.""" # Decode any inline Var markup and apply it to the instance @@ -331,35 +548,35 @@ def __format__(self, format_spec: str) -> str: return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}" @overload - def to(self, output: Type[StringVar]) -> ToStringOperation: ... + def to(self, output: Type[StringVar]) -> StringVar: ... @overload - def to(self, output: Type[str]) -> ToStringOperation: ... + def to(self, output: Type[str]) -> StringVar: ... @overload - def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ... + def to(self, output: Type[BooleanVar]) -> BooleanVar: ... @overload def to( self, output: Type[NumberVar], var_type: type[int] | type[float] = float - ) -> ToNumberVarOperation: ... + ) -> NumberVar: ... @overload def to( self, output: Type[ArrayVar], var_type: type[list] | type[tuple] | type[set] = list, - ) -> ToArrayOperation: ... + ) -> ArrayVar: ... @overload def to( self, output: Type[ObjectVar], var_type: types.GenericType = dict - ) -> ToObjectOperation: ... + ) -> ObjectVar: ... @overload def to( self, output: Type[FunctionVar], var_type: Type[Callable] = Callable - ) -> ToFunctionOperation: ... + ) -> FunctionVar: ... @overload def to( @@ -379,56 +596,26 @@ def to( output: The output type. var_type: The type of the var. - Raises: - TypeError: If the var_type is not a supported type for the output. - Returns: The converted var. """ - from reflex.event import ( - EventChain, - EventChainVar, - EventSpec, - EventVar, - ToEventChainVarOperation, - ToEventVarOperation, - ) - - from .function import FunctionVar, ToFunctionOperation - from .number import ( - BooleanVar, - NumberVar, - ToBooleanVarOperation, - ToNumberVarOperation, - ) - from .object import ObjectVar, ToObjectOperation - from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + from .object import ObjectVar base_type = var_type if types.is_optional(base_type): base_type = types.get_args(base_type)[0] - fixed_type = get_origin(base_type) or base_type - fixed_output_type = get_origin(output) or output # If the first argument is a python type, we map it to the corresponding Var type. - if fixed_output_type is dict: - return self.to(ObjectVar, output) - if fixed_output_type in (list, tuple, set): - return self.to(ArrayVar, output) - if fixed_output_type in (int, float): - return self.to(NumberVar, output) - if fixed_output_type is str: - return self.to(StringVar, output) - if fixed_output_type is bool: - return self.to(BooleanVar, output) + for var_subclass in _var_subclasses[::-1]: + if fixed_output_type in var_subclass.python_types: + return self.to(var_subclass.var_subclass, output) + if fixed_output_type is None: - return ToNoneOperation.create(self) - if fixed_output_type is EventSpec: - return self.to(EventVar, output) - if fixed_output_type is EventChain: - return self.to(EventChainVar, output) + return get_to_operation(NoneVar).create(self) # type: ignore + + # Handle fixed_output_type being Base or a dataclass. try: if issubclass(fixed_output_type, Base): return self.to(ObjectVar, output) @@ -440,57 +627,12 @@ def to( return self.to(ObjectVar, output) if inspect.isclass(output): - if issubclass(output, BooleanVar): - return ToBooleanVarOperation.create(self) - - if issubclass(output, NumberVar): - if fixed_type is not None: - if fixed_type in types.UnionTypes: - inner_types = get_args(base_type) - if not all(issubclass(t, (int, float)) for t in inner_types): - raise TypeError( - f"Unsupported type {var_type} for NumberVar. Must be int or float." - ) - - elif not issubclass(fixed_type, (int, float)): - raise TypeError( - f"Unsupported type {var_type} for NumberVar. Must be int or float." - ) - return ToNumberVarOperation.create(self, var_type or float) - - if issubclass(output, ArrayVar): - if fixed_type is not None and not issubclass( - fixed_type, (list, tuple, set) - ): - raise TypeError( - f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set." + for var_subclass in _var_subclasses[::-1]: + if issubclass(output, var_subclass.var_subclass): + to_operation_return = var_subclass.to_var_subclass.create( + value=self, _var_type=var_type ) - return ToArrayOperation.create(self, var_type or list) - - if issubclass(output, StringVar): - return ToStringOperation.create(self, var_type or str) - - if issubclass(output, EventVar): - return ToEventVarOperation.create(self, var_type or EventSpec) - - if issubclass(output, EventChainVar): - return ToEventChainVarOperation.create(self, var_type or EventChain) - - if issubclass(output, (ObjectVar, Base)): - return ToObjectOperation.create(self, var_type or dict) - - if issubclass(output, FunctionVar): - # if fixed_type is not None and not issubclass(fixed_type, Callable): - # raise TypeError( - # f"Unsupported type {var_type} for FunctionVar. Must be Callable." - # ) - return ToFunctionOperation.create(self, var_type or Callable) - - if issubclass(output, NoneVar): - return ToNoneOperation.create(self) - - if dataclasses.is_dataclass(output): - return ToObjectOperation.create(self, var_type or dict) + return to_operation_return # type: ignore # If we can't determine the first argument, we just replace the _var_type. if not issubclass(output, Var) or var_type is None: @@ -508,6 +650,18 @@ def to( return self + @overload + def guess_type(self: Var[str]) -> StringVar: ... + + @overload + def guess_type(self: Var[bool]) -> BooleanVar: ... + + @overload + def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... + + @overload + def guess_type(self) -> Self: ... + def guess_type(self) -> Var: """Guesses the type of the variable based on its `_var_type` attribute. @@ -517,11 +671,8 @@ def guess_type(self) -> Var: Raises: TypeError: If the type is not supported for guessing. """ - from reflex.event import EventChain, EventChainVar, EventSpec, EventVar - - from .number import BooleanVar, NumberVar + from .number import NumberVar from .object import ObjectVar - from .sequence import ArrayVar, StringVar var_type = self._var_type if var_type is None: @@ -558,20 +709,13 @@ def guess_type(self) -> Var: if not inspect.isclass(fixed_type): raise TypeError(f"Unsupported type {var_type} for guess_type.") - if issubclass(fixed_type, bool): - return self.to(BooleanVar, self._var_type) - if issubclass(fixed_type, (int, float)): - return self.to(NumberVar, self._var_type) - if issubclass(fixed_type, dict): - return self.to(ObjectVar, self._var_type) - if issubclass(fixed_type, (list, tuple, set)): - return self.to(ArrayVar, self._var_type) - if issubclass(fixed_type, str): - return self.to(StringVar, self._var_type) - if issubclass(fixed_type, EventSpec): - return self.to(EventVar, self._var_type) - if issubclass(fixed_type, EventChain): - return self.to(EventChainVar, self._var_type) + if fixed_type is None: + return self.to(None) + + for var_subclass in _var_subclasses[::-1]: + if issubclass(fixed_type, var_subclass.python_types): + return self.to(var_subclass.var_subclass, self._var_type) + try: if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) @@ -782,16 +926,23 @@ def __invert__(self) -> BooleanVar: """ return ~self.bool() - def to_string(self): + def to_string(self, use_json: bool = True) -> StringVar: """Convert the var to a string. + Args: + use_json: Whether to use JSON stringify. If False, uses Object.prototype.toString. + Returns: The string var. """ - from .function import JSON_STRINGIFY + from .function import JSON_STRINGIFY, PROTOTYPE_TO_STRING from .sequence import StringVar - return JSON_STRINGIFY.call(self).to(StringVar) + return ( + JSON_STRINGIFY.call(self).to(StringVar) + if use_json + else PROTOTYPE_TO_STRING.call(self).to(StringVar) + ) def as_ref(self) -> Var: """Get a reference to the var. @@ -1017,9 +1168,129 @@ def json(self) -> str: OUTPUT = TypeVar("OUTPUT", bound=Var) +class ToOperation: + """A var operation that converts a var to another type.""" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + from .object import ObjectVar + + if isinstance(self, ObjectVar) and name != "_js_expr": + return ObjectVar.__getattr__(self, name) + return getattr(self._original, name) + + def __post_init__(self): + """Post initialization.""" + object.__delattr__(self, "_js_expr") + + def __hash__(self) -> int: + """Calculate the hash value of the object. + + Returns: + int: The hash value of the object. + """ + return hash(self._original) + + def _get_all_var_data(self) -> VarData | None: + """Get all the var data. + + Returns: + The var data. + """ + return VarData.merge( + self._original._get_all_var_data(), + self._var_data, # type: ignore + ) + + @classmethod + def create( + cls, + value: Var, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ): + """Create a ToOperation. + + Args: + value: The value of the var. + _var_type: The type of the Var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The ToOperation. + """ + return cls( + _js_expr="", # type: ignore + _var_data=_var_data, # type: ignore + _var_type=_var_type or cls._default_var_type, # type: ignore + _original=value, # type: ignore + ) + + class LiteralVar(Var): """Base class for immutable literal vars.""" + def __init_subclass__(cls, **kwargs): + """Initialize the subclass. + + Args: + **kwargs: Additional keyword arguments. + + Raises: + TypeError: If the LiteralVar subclass does not have a corresponding Var subclass. + """ + super().__init_subclass__(**kwargs) + + bases = cls.__bases__ + + bases_normalized = [ + base if inspect.isclass(base) else get_origin(base) for base in bases + ] + + possible_bases = [ + base + for base in bases_normalized + if issubclass(base, Var) and base != LiteralVar + ] + + if not possible_bases: + raise TypeError( + f"LiteralVar subclass {cls} must have a base class that is a subclass of Var and not LiteralVar." + ) + + var_subclasses = [ + var_subclass + for var_subclass in _var_subclasses + if var_subclass.var_subclass in possible_bases + ] + + if not var_subclasses: + raise TypeError( + f"LiteralVar {cls} must have a base class annotated with `python_types`." + ) + + if len(var_subclasses) != 1: + raise TypeError( + f"LiteralVar {cls} must have exactly one base class annotated with `python_types`." + ) + + var_subclass = var_subclasses[0] + + # Remove the old subclass, happens because __init_subclass__ is called twice + # for each subclass. This is because of __slots__ in dataclasses. + for var_literal_subclass in list(_var_literal_subclasses): + if var_literal_subclass[1] is var_subclass: + _var_literal_subclasses.remove(var_literal_subclass) + + _var_literal_subclasses.append((cls, var_subclass)) + @classmethod def create( cls, @@ -1038,50 +1309,21 @@ def create( Raises: TypeError: If the value is not a supported type for LiteralVar. """ - from .number import LiteralBooleanVar, LiteralNumberVar from .object import LiteralObjectVar - from .sequence import LiteralArrayVar, LiteralStringVar + from .sequence import LiteralStringVar if isinstance(value, Var): if _var_data is None: return value return value._replace(merge_var_data=_var_data) - if isinstance(value, str): - return LiteralStringVar.create(value, _var_data=_var_data) - - if isinstance(value, bool): - return LiteralBooleanVar.create(value, _var_data=_var_data) - - if isinstance(value, (int, float)): - return LiteralNumberVar.create(value, _var_data=_var_data) + for literal_subclass, var_subclass in _var_literal_subclasses[::-1]: + if isinstance(value, var_subclass.python_types): + return literal_subclass.create(value, _var_data=_var_data) - if isinstance(value, dict): - return LiteralObjectVar.create(value, _var_data=_var_data) - - if isinstance(value, (list, tuple, set)): - return LiteralArrayVar.create(value, _var_data=_var_data) - - if value is None: - return LiteralNoneVar.create(_var_data=_var_data) - - from reflex.event import ( - EventChain, - EventHandler, - EventSpec, - LiteralEventChainVar, - LiteralEventVar, - ) + from reflex.event import EventHandler from reflex.utils.format import get_event_handler_parts - from .object import LiteralObjectVar - - if isinstance(value, EventSpec): - return LiteralEventVar.create(value, _var_data=_var_data) - - if isinstance(value, EventChain): - return LiteralEventChainVar.create(value, _var_data=_var_data) - if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -1155,6 +1397,22 @@ def serialize_literal(value: LiteralVar): return value._var_value +def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None: + """Get the Python literal value. + + Args: + value: The value to get the Python literal value of. + + Returns: + The Python literal value. + """ + if isinstance(value, LiteralVar): + return value._var_value + if isinstance(value, Var): + return None + return value + + P = ParamSpec("P") T = TypeVar("T") @@ -1205,6 +1463,12 @@ def var_operation( ) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ... +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[T]], +) -> Callable[P, Var[T]]: ... + + def var_operation( func: Callable[P, CustomVarOperationReturn[T]], ) -> Callable[P, Var[T]]: @@ -1237,6 +1501,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]: } return CustomVarOperation.create( + name=func.__name__, args=tuple(list(args_vars.items()) + list(kwargs_vars.items())), return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore ).guess_type() @@ -2059,6 +2324,8 @@ def var_operation_return( class CustomVarOperation(CachedVarOperation, Var[T]): """Base class for custom var operations.""" + _name: str = dataclasses.field(default="") + _args: Tuple[Tuple[str, Var], ...] = dataclasses.field(default_factory=tuple) _return: CustomVarOperationReturn[T] = dataclasses.field( @@ -2093,6 +2360,7 @@ def _cached_get_all_var_data(self) -> VarData | None: @classmethod def create( cls, + name: str, args: Tuple[Tuple[str, Var], ...], return_var: CustomVarOperationReturn[T], _var_data: VarData | None = None, @@ -2100,6 +2368,7 @@ def create( """Create a CustomVarOperation. Args: + name: The name of the operation. args: The arguments to the operation. return_var: The return var. _var_data: Additional hooks and imports associated with the Var. @@ -2111,12 +2380,13 @@ def create( _js_expr="", _var_type=return_var._var_type, _var_data=_var_data, + _name=name, _args=args, _return=return_var, ) -class NoneVar(Var[None]): +class NoneVar(Var[None], python_types=type(None)): """A var representing None.""" @@ -2141,11 +2411,13 @@ def json(self) -> str: @classmethod def create( cls, + value: None = None, _var_data: VarData | None = None, ) -> LiteralNoneVar: """Create a var from a value. Args: + value: The value of the var. Must be None. Existed for compatibility with LiteralVar. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -2158,48 +2430,26 @@ def create( ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToNoneOperation(CachedVarOperation, NoneVar): - """A var operation that converts a var to None.""" - - _original_var: Var = dataclasses.field( - default_factory=lambda: LiteralNoneVar.create() - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """Get the cached var name. - - Returns: - The cached var name. - """ - return str(self._original_var) +def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]: + """Get the ToOperation class for a given Var subclass. - @classmethod - def create( - cls, - var: Var, - _var_data: VarData | None = None, - ) -> ToNoneOperation: - """Create a ToNoneOperation. + Args: + var_subclass: The Var subclass. - Args: - var: The var to convert to None. - _var_data: Additional hooks and imports associated with the Var. + Returns: + The ToOperation class. - Returns: - The ToNoneOperation. - """ - return ToNoneOperation( - _js_expr="", - _var_type=None, - _var_data=_var_data, - _original_var=var, - ) + Raises: + ValueError: If the ToOperation class cannot be found. + """ + possible_classes = [ + saved_var_subclass.to_var_subclass + for saved_var_subclass in _var_subclasses + if saved_var_subclass.var_subclass is var_subclass + ] + if not possible_classes: + raise ValueError(f"Could not find ToOperation for {var_subclass}.") + return possible_classes[0] @dataclasses.dataclass( @@ -2262,68 +2512,6 @@ def create( ) -class ToOperation: - """A var operation that converts a var to another type.""" - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - return getattr(object.__getattribute__(self, "_original"), name) - - def __post_init__(self): - """Post initialization.""" - object.__delattr__(self, "_js_expr") - - def __hash__(self) -> int: - """Calculate the hash value of the object. - - Returns: - int: The hash value of the object. - """ - return hash(object.__getattribute__(self, "_original")) - - def _get_all_var_data(self) -> VarData | None: - """Get all the var data. - - Returns: - The var data. - """ - return VarData.merge( - object.__getattribute__(self, "_original")._get_all_var_data(), - self._var_data, # type: ignore - ) - - @classmethod - def create( - cls, - value: Var, - _var_type: GenericType | None = None, - _var_data: VarData | None = None, - ): - """Create a ToOperation. - - Args: - value: The value of the var. - _var_type: The type of the Var. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The ToOperation. - """ - return cls( - _js_expr="", # type: ignore - _var_data=_var_data, # type: ignore - _var_type=_var_type or cls._default_var_type, # type: ignore - _original=value, # type: ignore - ) - - def get_uuid_string_var() -> Var: """Return a Var that generates a single memoized UUID via .web/utils/state.js. @@ -2369,168 +2557,6 @@ def get_unique_variable_name() -> str: return get_unique_variable_name() -@dataclasses.dataclass( - eq=True, - frozen=True, -) -class VarData: - """Metadata associated with a x.""" - - # The name of the enclosing state. - state: str = dataclasses.field(default="") - - # The name of the field in the state. - field_name: str = dataclasses.field(default="") - - # Imports needed to render this var - imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple) - - # Hooks that need to be present in the component to render this var - hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) - - def __init__( - self, - state: str = "", - field_name: str = "", - imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, None] | None = None, - ): - """Initialize the var data. - - Args: - state: The name of the enclosing state. - field_name: The name of the field in the state. - imports: Imports needed to render this var. - hooks: Hooks that need to be present in the component to render this var. - """ - immutable_imports: ImmutableParsedImportDict = tuple( - sorted( - ((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items()) - ) - ) - object.__setattr__(self, "state", state) - object.__setattr__(self, "field_name", field_name) - object.__setattr__(self, "imports", immutable_imports) - object.__setattr__(self, "hooks", tuple(hooks or {})) - - def old_school_imports(self) -> ImportDict: - """Return the imports as a mutable dict. - - Returns: - The imports as a mutable dict. - """ - return dict((k, list(v)) for k, v in self.imports) - - @classmethod - def merge(cls, *others: VarData | None) -> VarData | None: - """Merge multiple var data objects. - - Args: - *others: The var data objects to merge. - - Returns: - The merged var data object. - """ - state = "" - field_name = "" - _imports = {} - hooks = {} - for var_data in others: - if var_data is None: - continue - state = state or var_data.state - field_name = field_name or var_data.field_name - _imports = imports.merge_imports(_imports, var_data.imports) - hooks.update( - var_data.hooks - if isinstance(var_data.hooks, dict) - else {k: None for k in var_data.hooks} - ) - - if state or _imports or hooks or field_name: - return VarData( - state=state, - field_name=field_name, - imports=_imports, - hooks=hooks, - ) - return None - - def __bool__(self) -> bool: - """Check if the var data is non-empty. - - Returns: - True if any field is set to a non-default value. - """ - return bool(self.state or self.imports or self.hooks or self.field_name) - - @classmethod - def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: - """Set the state of the var. - - Args: - state: The state to set or the full name of the state. - field_name: The name of the field in the state. Optional. - - Returns: - The var with the set state. - """ - from reflex.utils import format - - state_name = state if isinstance(state, str) else state.get_full_name() - return VarData( - state=state_name, - field_name=field_name, - hooks={ - "const {0} = useContext(StateContexts.{0})".format( - format.format_state_name(state_name) - ): None - }, - imports={ - f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], - "react": [ImportVar(tag="useContext")], - }, - ) - - -def _decode_var_immutable(value: str) -> tuple[VarData | None, str]: - """Decode the state name from a formatted var. - - Args: - value: The value to extract the state name from. - - Returns: - The extracted state name and the value without the state name. - """ - var_datas = [] - if isinstance(value, str): - # fast path if there is no encoded VarData - if constants.REFLEX_VAR_OPENING_TAG not in value: - return None, value - - offset = 0 - - # Find all tags. - while m := _decode_var_pattern.search(value): - start, end = m.span() - value = value[:start] + value[end:] - - serialized_data = m.group(1) - - if serialized_data.isnumeric() or ( - serialized_data[0] == "-" and serialized_data[1:].isnumeric() - ): - # This is a global immutable var. - var = _global_vars[int(serialized_data)] - var_data = var._get_all_var_data() - - if var_data is not None: - var_datas.append(var_data) - offset += end - start - - return VarData.merge(*var_datas) if var_datas else None, value - - # Compile regex for finding reflex var tags. _decode_var_pattern_re = ( rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" diff --git a/reflex/vars/function.py b/reflex/vars/function.py index a512432b9a..a1f7fb7bd6 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -4,21 +4,20 @@ import dataclasses import sys -from typing import Any, Callable, ClassVar, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Tuple, Type, Union from reflex.utils.types import GenericType from .base import ( CachedVarOperation, LiteralVar, - ToOperation, Var, VarData, cached_property_no_lock, ) -class FunctionVar(Var[Callable]): +class FunctionVar(Var[Callable], python_types=Callable): """Base class for immutable function vars.""" def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: @@ -180,17 +179,7 @@ def create( ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToFunctionOperation(ToOperation, FunctionVar): - """Base class of converting a var to a function.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - - _default_var_type: ClassVar[GenericType] = Callable - - JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify") +PROTOTYPE_TO_STRING = FunctionStringVar.create( + "((__to_string) => __to_string.toString())" +) diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 0aaa7a0685..77c728d137 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -10,7 +10,6 @@ TYPE_CHECKING, Any, Callable, - ClassVar, NoReturn, Type, TypeVar, @@ -25,9 +24,7 @@ from .base import ( CustomVarOperationReturn, - LiteralNoneVar, LiteralVar, - ToOperation, Var, VarData, unionize, @@ -58,7 +55,7 @@ def raise_unsupported_operand_types( ) -class NumberVar(Var[NUMBER_T]): +class NumberVar(Var[NUMBER_T], python_types=(int, float)): """Base class for immutable number vars.""" @overload @@ -760,7 +757,7 @@ def number_trunc_operation(value: NumberVar): return var_operation_return(js_expression=f"Math.trunc({value})", var_type=int) -class BooleanVar(NumberVar[bool]): +class BooleanVar(NumberVar[bool], python_types=bool): """Base class for immutable boolean vars.""" def __invert__(self): @@ -989,18 +986,25 @@ def boolean_not_operation(value: BooleanVar): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralBooleanVar(LiteralVar, BooleanVar): - """Base class for immutable literal boolean vars.""" +class LiteralNumberVar(LiteralVar, NumberVar): + """Base class for immutable literal number vars.""" - _var_value: bool = dataclasses.field(default=False) + _var_value: float | int = dataclasses.field(default=0) def json(self) -> str: """Get the JSON representation of the var. Returns: The JSON representation of the var. + + Raises: + PrimitiveUnserializableToJSON: If the var is unserializable to JSON. """ - return "true" if self._var_value else "false" + if math.isinf(self._var_value) or math.isnan(self._var_value): + raise PrimitiveUnserializableToJSON( + f"No valid JSON representation for {self}" + ) + return json.dumps(self._var_value) def __hash__(self) -> int: """Calculate the hash value of the object. @@ -1011,19 +1015,26 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, self._var_value)) @classmethod - def create(cls, value: bool, _var_data: VarData | None = None): - """Create the boolean var. + def create(cls, value: float | int, _var_data: VarData | None = None): + """Create the number var. Args: value: The value of the var. _var_data: Additional hooks and imports associated with the Var. Returns: - The boolean var. + The number var. """ + if math.isinf(value): + js_expr = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + js_expr = "NaN" + else: + js_expr = str(value) + return cls( - _js_expr="true" if value else "false", - _var_type=bool, + _js_expr=js_expr, + _var_type=type(value), _var_data=_var_data, _var_value=value, ) @@ -1034,25 +1045,18 @@ def create(cls, value: bool, _var_data: VarData | None = None): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralNumberVar(LiteralVar, NumberVar): - """Base class for immutable literal number vars.""" +class LiteralBooleanVar(LiteralVar, BooleanVar): + """Base class for immutable literal boolean vars.""" - _var_value: float | int = dataclasses.field(default=0) + _var_value: bool = dataclasses.field(default=False) def json(self) -> str: """Get the JSON representation of the var. Returns: The JSON representation of the var. - - Raises: - PrimitiveUnserializableToJSON: If the var is unserializable to JSON. """ - if math.isinf(self._var_value) or math.isnan(self._var_value): - raise PrimitiveUnserializableToJSON( - f"No valid JSON representation for {self}" - ) - return json.dumps(self._var_value) + return "true" if self._var_value else "false" def __hash__(self) -> int: """Calculate the hash value of the object. @@ -1063,26 +1067,19 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, self._var_value)) @classmethod - def create(cls, value: float | int, _var_data: VarData | None = None): - """Create the number var. + def create(cls, value: bool, _var_data: VarData | None = None): + """Create the boolean var. Args: value: The value of the var. _var_data: Additional hooks and imports associated with the Var. Returns: - The number var. + The boolean var. """ - if math.isinf(value): - js_expr = "Infinity" if value > 0 else "-Infinity" - elif math.isnan(value): - js_expr = "NaN" - else: - js_expr = str(value) - return cls( - _js_expr=js_expr, - _var_type=type(value), + _js_expr="true" if value else "false", + _var_type=bool, _var_data=_var_data, _var_value=value, ) @@ -1092,32 +1089,6 @@ def create(cls, value: float | int, _var_data: VarData | None = None): boolean_types = Union[BooleanVar, bool] -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToNumberVarOperation(ToOperation, NumberVar): - """Base class for immutable number vars that are the result of a number operation.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = float - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToBooleanVarOperation(ToOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a boolean operation.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = bool - - _IS_TRUE_IMPORT: ImportDict = { f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], } @@ -1140,8 +1111,12 @@ def boolify(value: Var): ) +T = TypeVar("T") +U = TypeVar("U") + + @var_operation -def ternary_operation(condition: BooleanVar, if_true: Var, if_false: Var): +def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]): """Create a ternary operation. Args: @@ -1152,10 +1127,14 @@ def ternary_operation(condition: BooleanVar, if_true: Var, if_false: Var): Returns: The ternary operation. """ - return var_operation_return( + type_value: Union[Type[T], Type[U]] = unionize( + if_true._var_type, if_false._var_type + ) + value: CustomVarOperationReturn[Union[T, U]] = var_operation_return( js_expression=f"({condition} ? {if_true} : {if_false})", - var_type=unionize(if_true._var_type, if_false._var_type), + var_type=type_value, ) + return value NUMBER_TYPES = (int, float, NumberVar) diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 38add77797..56f3535d80 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -8,7 +8,6 @@ from inspect import isclass from typing import ( Any, - ClassVar, Dict, List, NoReturn, @@ -27,7 +26,6 @@ from .base import ( CachedVarOperation, LiteralVar, - ToOperation, Var, VarData, cached_property_no_lock, @@ -48,7 +46,7 @@ OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") -class ObjectVar(Var[OBJECT_TYPE]): +class ObjectVar(Var[OBJECT_TYPE], python_types=dict): """Base class for immutable object vars.""" def _key_type(self) -> Type: @@ -521,34 +519,6 @@ def create( ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToObjectOperation(ToOperation, ObjectVar): - """Operation to convert a var to an object.""" - - _original: Var = dataclasses.field( - default_factory=lambda: LiteralObjectVar.create({}) - ) - - _default_var_type: ClassVar[GenericType] = dict - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_js_expr": - return self._original._js_expr - return ObjectVar.__getattr__(self, name) - - @var_operation def object_has_own_property_operation(object: ObjectVar, key: Var): """Check if an object has a key. diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 9b65507b7c..6d36f06fa0 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -11,7 +11,6 @@ from typing import ( TYPE_CHECKING, Any, - ClassVar, Dict, List, Literal, @@ -19,27 +18,28 @@ Set, Tuple, Type, - TypeVar, Union, overload, ) +from typing_extensions import TypeVar + from reflex import constants from reflex.constants.base import REFLEX_VAR_OPENING_TAG +from reflex.constants.colors import Color from reflex.utils.exceptions import VarTypeError from reflex.utils.types import GenericType, get_origin from .base import ( CachedVarOperation, CustomVarOperationReturn, - LiteralNoneVar, LiteralVar, - ToOperation, Var, VarData, _global_vars, cached_property_no_lock, figure_out_type, + get_python_literal, get_unique_variable_name, unionize, var_operation, @@ -50,13 +50,16 @@ LiteralNumberVar, NumberVar, raise_unsupported_operand_types, + ternary_operation, ) if TYPE_CHECKING: from .object import ObjectVar +STRING_TYPE = TypeVar("STRING_TYPE", default=str) + -class StringVar(Var[str]): +class StringVar(Var[STRING_TYPE], python_types=str): """Base class for immutable string vars.""" @overload @@ -350,7 +353,7 @@ def __ge__(self, other: Any): @var_operation -def string_lt_operation(lhs: StringVar | str, rhs: StringVar | str): +def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): """Check if a string is less than another string. Args: @@ -364,7 +367,7 @@ def string_lt_operation(lhs: StringVar | str, rhs: StringVar | str): @var_operation -def string_gt_operation(lhs: StringVar | str, rhs: StringVar | str): +def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): """Check if a string is greater than another string. Args: @@ -378,7 +381,7 @@ def string_gt_operation(lhs: StringVar | str, rhs: StringVar | str): @var_operation -def string_le_operation(lhs: StringVar | str, rhs: StringVar | str): +def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): """Check if a string is less than or equal to another string. Args: @@ -392,7 +395,7 @@ def string_le_operation(lhs: StringVar | str, rhs: StringVar | str): @var_operation -def string_ge_operation(lhs: StringVar | str, rhs: StringVar | str): +def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): """Check if a string is greater than or equal to another string. Args: @@ -406,7 +409,7 @@ def string_ge_operation(lhs: StringVar | str, rhs: StringVar | str): @var_operation -def string_lower_operation(string: StringVar): +def string_lower_operation(string: StringVar[Any]): """Convert a string to lowercase. Args: @@ -419,7 +422,7 @@ def string_lower_operation(string: StringVar): @var_operation -def string_upper_operation(string: StringVar): +def string_upper_operation(string: StringVar[Any]): """Convert a string to uppercase. Args: @@ -432,7 +435,7 @@ def string_upper_operation(string: StringVar): @var_operation -def string_strip_operation(string: StringVar): +def string_strip_operation(string: StringVar[Any]): """Strip a string. Args: @@ -446,7 +449,7 @@ def string_strip_operation(string: StringVar): @var_operation def string_contains_field_operation( - haystack: StringVar, needle: StringVar | str, field: StringVar | str + haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str ): """Check if a string contains another string. @@ -465,7 +468,7 @@ def string_contains_field_operation( @var_operation -def string_contains_operation(haystack: StringVar, needle: StringVar | str): +def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str): """Check if a string contains another string. Args: @@ -481,7 +484,9 @@ def string_contains_operation(haystack: StringVar, needle: StringVar | str): @var_operation -def string_starts_with_operation(full_string: StringVar, prefix: StringVar | str): +def string_starts_with_operation( + full_string: StringVar[Any], prefix: StringVar[Any] | str +): """Check if a string starts with a prefix. Args: @@ -497,7 +502,7 @@ def string_starts_with_operation(full_string: StringVar, prefix: StringVar | str @var_operation -def string_item_operation(string: StringVar, index: NumberVar | int): +def string_item_operation(string: StringVar[Any], index: NumberVar | int): """Get an item from a string. Args: @@ -511,7 +516,7 @@ def string_item_operation(string: StringVar, index: NumberVar | int): @var_operation -def array_join_operation(array: ArrayVar, sep: StringVar | str = ""): +def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""): """Join the elements of an array. Args: @@ -536,7 +541,7 @@ def array_join_operation(array: ArrayVar, sep: StringVar | str = ""): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralStringVar(LiteralVar, StringVar): +class LiteralStringVar(LiteralVar, StringVar[str]): """Base class for immutable literal string vars.""" _var_value: str = dataclasses.field(default="") @@ -658,7 +663,7 @@ def json(self) -> str: frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class ConcatVarOperation(CachedVarOperation, StringVar): +class ConcatVarOperation(CachedVarOperation, StringVar[str]): """Representing a concatenation of literal string vars.""" _var_value: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) @@ -742,7 +747,7 @@ def create( VALUE_TYPE = TypeVar("VALUE_TYPE") -class ArrayVar(Var[ARRAY_VAR_TYPE]): +class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): """Base class for immutable array vars.""" @overload @@ -1272,7 +1277,7 @@ def create( @var_operation -def string_split_operation(string: StringVar, sep: StringVar | str = ""): +def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): """Split a string. Args: @@ -1569,32 +1574,6 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToStringOperation(ToOperation, StringVar): - """Base class for immutable string vars that are the result of a to string operation.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = str - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ToArrayOperation(ToOperation, ArrayVar): - """Base class for immutable array vars that are the result of a to array operation.""" - - _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) - - _default_var_type: ClassVar[Type] = List[Any] - - @var_operation def repeat_array_operation( array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int @@ -1654,3 +1633,134 @@ def array_concat_operation( js_expression=f"[...{lhs}, ...{rhs}]", var_type=Union[lhs._var_type, rhs._var_type], ) + + +class ColorVar(StringVar[Color], python_types=Color): + """Base class for immutable color vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): + """Base class for immutable literal color vars.""" + + _var_value: Color = dataclasses.field(default_factory=lambda: Color(color="black")) + + @classmethod + def create( + cls, + value: Color, + _var_type: Type[Color] | None = None, + _var_data: VarData | None = None, + ) -> ColorVar: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _js_expr="", + _var_type=_var_type or Color, + _var_data=_var_data, + _var_value=value, + ) + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash( + ( + self.__class__.__name__, + self._var_value.color, + self._var_value.alpha, + self._var_value.shade, + ) + ) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + alpha = self._var_value.alpha + alpha = ( + ternary_operation( + alpha, + LiteralStringVar.create("a"), + LiteralStringVar.create(""), + ) + if isinstance(alpha, Var) + else LiteralStringVar.create("a" if alpha else "") + ) + + shade = self._var_value.shade + shade = ( + shade.to_string(use_json=False) + if isinstance(shade, Var) + else LiteralStringVar.create(str(shade)) + ) + return str( + ConcatVarOperation.create( + LiteralStringVar.create("var(--"), + self._var_value.color, + LiteralStringVar.create("-"), + alpha, + shade, + LiteralStringVar.create(")"), + ) + ) + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> VarData | None: + """Get all the var data. + + Returns: + The var data. + """ + return VarData.merge( + *[ + LiteralVar.create(var)._get_all_var_data() + for var in ( + self._var_value.color, + self._var_value.alpha, + self._var_value.shade, + ) + ], + self._var_data, + ) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + + Raises: + TypeError: If the color is not a valid color. + """ + color, alpha, shade = map( + get_python_literal, + (self._var_value.color, self._var_value.alpha, self._var_value.shade), + ) + if color is None or alpha is None or shade is None: + raise TypeError("Cannot serialize color that contains non-literal vars.") + if ( + not isinstance(color, str) + or not isinstance(alpha, bool) + or not isinstance(shade, int) + ): + raise TypeError("Color is not a valid color.") + return f"var(--{color}-{'a' if alpha else ''}{shade})" diff --git a/tests/units/components/core/test_colors.py b/tests/units/components/core/test_colors.py index a6175d56ae..74fbeb20f4 100644 --- a/tests/units/components/core/test_colors.py +++ b/tests/units/components/core/test_colors.py @@ -14,6 +14,7 @@ class ColorState(rx.State): color: str = "mint" color_part: str = "tom" shade: int = 4 + alpha: bool = False color_state_name = ColorState.get_full_name().replace(".", "__") @@ -31,7 +32,14 @@ def create_color_var(color): (create_color_var(rx.color("mint", 3, True)), '"var(--mint-a3)"', Color), ( create_color_var(rx.color(ColorState.color, ColorState.shade)), # type: ignore - f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")', + f'("var(--"+{str(color_state_name)}.color+"-"+(((__to_string) => __to_string.toString())({str(color_state_name)}.shade))+")")', + Color, + ), + ( + create_color_var( + rx.color(ColorState.color, ColorState.shade, ColorState.alpha) # type: ignore + ), + f'("var(--"+{str(color_state_name)}.color+"-"+({str(color_state_name)}.alpha ? "a" : "")+(((__to_string) => __to_string.toString())({str(color_state_name)}.shade))+")")', Color, ), ( @@ -43,7 +51,7 @@ def create_color_var(color): create_color_var( rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}") # type: ignore ), - f'("var(--"+{str(color_state_name)}.color_part+"ato-"+{str(color_state_name)}.shade+")")', + f'("var(--"+({str(color_state_name)}.color_part+"ato")+"-"+{str(color_state_name)}.shade+")")', Color, ), ( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index c04e554a91..a8b4b759d6 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -519,8 +519,8 @@ def test_var_indexing_types(var, type_): type_ : The type on indexed object. """ - assert var[2]._var_type == type_[0] - assert var[3]._var_type == type_[1] + assert var[0]._var_type == type_[0] + assert var[1]._var_type == type_[1] def test_var_indexing_str():