Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved get_attribute_access_type #3156

Merged
merged 7 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

import contextlib
import inspect
import sys
import types
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
_GenericAlias, # type: ignore
Expand All @@ -37,11 +40,16 @@

from sqlalchemy.ext.associationproxy import AssociationProxyInstance
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
QueryableAttribute,
Relationship,
)

from reflex import constants
from reflex.base import Base
from reflex.utils import serializers
from reflex.utils import console, serializers

# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]
Expand Down Expand Up @@ -76,6 +84,13 @@
ArgsSpec = Callable


PrimitiveToAnnotation = {
list: List,
tuple: Tuple,
dict: Dict,
}


class Unset:
"""A class to represent an unset value.

Expand Down Expand Up @@ -192,7 +207,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns:
return insp.columns[name].type.python_type
# check for list types
column = insp.columns[name]
column_type = column.type
type_ = insp.columns[name].type.python_type
if hasattr(column_type, "item_type") and (
item_type := column_type.item_type.python_type # type: ignore
):
if type_ in PrimitiveToAnnotation:
type_ = PrimitiveToAnnotation[type_] # type: ignore
type_ = type_[item_type] # type: ignore
if column.nullable:
type_ = Optional[type_]
return type_
if name not in insp.all_orm_descriptors:
return None
descriptor = insp.all_orm_descriptors[name]
Expand All @@ -202,11 +229,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
prop = descriptor.property
if not isinstance(prop, Relationship):
return None
class_ = prop.mapper.class_
if prop.uselist:
return List[class_]
else:
return class_
type_ = prop.mapper.class_
# TODO: check for nullable?
type_ = List[type_] if prop.uselist else Optional[type_]
return type_
if isinstance(attr, AssociationProxyInstance):
return List[
get_attribute_access_type(
Expand All @@ -232,6 +258,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
elif isinstance(cls, type):
# Bare class
if sys.version_info >= (3, 10):
exceptions = NameError
else:
exceptions = (NameError, TypeError)
try:
hints = get_type_hints(cls)
if name in hints:
return hints[name]
except exceptions as e:
console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
pass
return None # Attribute is not accessible.


Expand Down
161 changes: 161 additions & 0 deletions tests/test_attribute_access_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import annotations

from typing import List, Optional

import pytest
import sqlalchemy
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type


class SQLABase(DeclarativeBase):
"""Base class for bare SQLAlchemy models."""

pass


class SQLATag(SQLABase):
"""Tag sqlalchemy model."""

__tablename__: str = "tag"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column()


class SQLALabel(SQLABase):
"""Label sqlalchemy model."""

__tablename__: str = "label"
id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels")


class SQLAClass(SQLABase):
"""Test sqlalchemy model."""

__tablename__: str = "test"
id: Mapped[int] = mapped_column(primary_key=True)
count: Mapped[int] = mapped_column()
name: Mapped[str] = mapped_column()
int_list: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
)
str_list: Mapped[List[str]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
)
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")

@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name


class ModelClass(rx.Model):
"""Test reflex model."""

count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []

@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name


class BaseClass(rx.Base):
"""Test rx.Base class."""

count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []

@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name


class BareClass:
"""Bare python class."""

count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []

@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name


@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass])
def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test.
Args:
request: pytest request object.
Returns:
Class to test.
"""
return request.param


@pytest.mark.parametrize(
"attr, expected",
[
pytest.param("count", int, id="int"),
pytest.param("name", str, id="str"),
pytest.param("int_list", List[int], id="List[int]"),
pytest.param("str_list", List[str], id="List[str]"),
pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
pytest.param("str_property", str, id="str_property"),
],
)
def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
"""Test get_attribute_access_type returns the correct type.
Args:
cls: Class to test.
attr: Attribute to test.
expected: Expected type.
"""
assert get_attribute_access_type(cls, attr) == expected
Loading