Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Var type handling for better rx.Model attribute access #2010

Merged
merged 9 commits into from
Oct 25, 2023
31 changes: 30 additions & 1 deletion reflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.orm
import sqlmodel

from reflex import constants
Expand Down Expand Up @@ -68,6 +69,22 @@ def __init_subclass__(cls):

super().__init_subclass__()

@classmethod
def _dict_recursive(cls, value):
"""Recursively serialize the relationship object(s).

Args:
value: The value to serialize.

Returns:
The serialized value.
"""
if hasattr(value, "dict"):
return value.dict()
elif isinstance(value, list):
return [cls._dict_recursive(item) for item in value]
return value

def dict(self, **kwargs):
"""Convert the object to a dictionary.

Expand All @@ -77,7 +94,19 @@ def dict(self, **kwargs):
Returns:
The object as a dictionary.
"""
return {name: getattr(self, name) for name in self.__fields__}
base_fields = {name: getattr(self, name) for name in self.__fields__}
relationships = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some comments explaining these lines

# SQLModel relationships do not appear in __fields__, but should be included if present.
for name in self.__sqlmodel_relationships__:
try:
relationships[name] = self._dict_recursive(getattr(self, name))
except sqlalchemy.orm.exc.DetachedInstanceError:
# This happens when the relationship was never loaded and the session is closed.
continue
return {
**base_fields,
**relationships,
}

@staticmethod
def create_all():
Expand Down
7 changes: 7 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,13 @@ def _set_default_value(cls, prop: BaseVar):
if default_value is not None:
field.required = False
field.default = default_value
if (
not field.required
and field.default is None
and not types.is_optional(prop._var_type)
):
# Ensure frontend uses null coalescing when accessing.
prop._var_type = Optional[prop._var_type]

@staticmethod
def _get_base_functions() -> dict[str, FunctionType]:
Expand Down
102 changes: 79 additions & 23 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,22 @@
from __future__ import annotations

import contextlib
import typing
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore
import types
from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Type,
Union,
_GenericAlias, # type: ignore
get_args,
get_origin,
get_type_hints,
)

from pydantic.fields import ModelField

from reflex.base import Base
from reflex.utils import serializers
Expand All @@ -21,18 +35,6 @@
ArgsSpec = Callable


def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
"""Get the arguments of a type alias.

Args:
alias: The type alias.

Returns:
The arguments of the type alias.
"""
return alias.__args__


def is_generic_alias(cls: GenericType) -> bool:
"""Check whether the class is a generic alias.

Expand Down Expand Up @@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
Returns:
Whether the class is a Union.
"""
with contextlib.suppress(ImportError):
from typing import _UnionGenericAlias # type: ignore
# UnionType added in py3.10
if not hasattr(types, "UnionType"):
return get_origin(cls) is Union

return isinstance(cls, _UnionGenericAlias)
return cls.__origin__ == Union if is_generic_alias(cls) else False
return get_origin(cls) in [Union, types.UnionType]


def is_literal(cls: GenericType) -> bool:
Expand All @@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
Returns:
Whether the class is a literal.
"""
return hasattr(cls, "__origin__") and cls.__origin__ is Literal
return get_origin(cls) is Literal


def is_optional(cls: GenericType) -> bool:
"""Check if a class is an Optional.

Args:
cls: The class to check.

Returns:
Whether the class is an Optional.
"""
return is_union(cls) and type(None) in get_args(cls)


def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
"""Check if an attribute can be accessed on the cls and return its type.

Supports pydantic models, unions, and annotated attributes on rx.Model.

Args:
cls: The class to check.
name: The name of the attribute to check.

Returns:
The type of the attribute, if accessible, or None
"""
from reflex.model import Model

if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if not field.required and field.default is None:
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls)
if name in hints:
type_ = hints[name]
if isinstance(type_, ModelField):
return type_.type_
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add return None

return None # Attribute is not accessible.


def get_base_class(cls: GenericType) -> Type:
Expand Down Expand Up @@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
Returns:
Whether the value is a dataframe.
"""
if is_generic_alias(value) or value == typing.Any:
if is_generic_alias(value) or value == Any:
return False
return value.__name__ == "DataFrame"

Expand All @@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
Returns:
Whether the type is a valid prop type.
"""
if is_union(type_):
return all((is_valid_var_type(arg) for arg in get_args(type_)))
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)


Expand All @@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
return name.startswith("_") and not name.startswith("__")


def check_type_in_allowed_types(
value_type: Type, allowed_types: typing.Iterable
) -> bool:
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
"""Check that a value type is found in a list of allowed types.

Args:
Expand Down
18 changes: 8 additions & 10 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
get_type_hints,
)

from pydantic.fields import ModelField

from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, serializers, types
Expand Down Expand Up @@ -420,15 +418,12 @@ def __getattr__(self, name: str) -> Var:
raise TypeError(
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
) from None
if (
hasattr(self._var_type, "__fields__")
and name in self._var_type.__fields__
):
type_ = self._var_type.__fields__[name].outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
is_optional = types.is_optional(self._var_type)
type_ = types.get_attribute_access_type(self._var_type, name)

if type_ is not None:
return BaseVar(
_var_name=f"{self._var_name}.{name}",
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
Expand Down Expand Up @@ -1235,6 +1230,9 @@ def get_default_value(self) -> Any:
Raises:
ImportError: If the var is a dataframe and pandas is not installed.
"""
if types.is_optional(self._var_type):
return None

type_ = (
get_origin(self._var_type)
if types.is_generic_alias(self._var_type)
Expand Down
53 changes: 51 additions & 2 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import os
import sys
from typing import Dict, Generator, List
from typing import Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock

import pytest
Expand All @@ -30,7 +30,7 @@
StateProxy,
StateUpdate,
)
from reflex.utils import prerequisites
from reflex.utils import prerequisites, types
from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar

Expand Down Expand Up @@ -2239,3 +2239,52 @@ class MutableResetState(State):
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default


class Custom1(Base):
"""A custom class with a str field."""

foo: str


class Custom2(Base):
"""A custom class with a Custom1 field."""

c1: Optional[Custom1] = None
c1r: Custom1


class Custom3(Base):
"""A custom class with a Custom2 field."""

c2: Optional[Custom2] = None
c2r: Custom2


def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars."""

class UnionState(State):
int_float: Union[int, float] = 0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think other Var operations are prepared to handle bare Union types...

opt_int: Optional[int]
c3: Optional[Custom3]
c3i: Custom3 # implicitly required
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")

assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore
assert UnionState.custom_union.foo is not None # type: ignore
assert UnionState.custom_union.c1 is not None # type: ignore
assert UnionState.custom_union.c1r is not None # type: ignore
assert UnionState.custom_union.c2 is not None # type: ignore
assert UnionState.custom_union.c2r is not None # type: ignore
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
assert types.is_union(UnionState.int_float._var_type) # type: ignore