Skip to content

Commit

Permalink
fix: handle default_factory in get_attribute_access_type (#4517)
Browse files Browse the repository at this point in the history
* fix: handle default_factory in get_attribute_access_type, add tests for sqla dataclasses

* only test classes which have default_factory + add test for no default
  • Loading branch information
benedikt-bartscher authored Dec 12, 2024
1 parent 95eb663 commit e4b5755
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 17 deletions.
6 changes: 5 additions & 1 deletion reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if not field.required and field.default is None:
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
Expand Down
128 changes: 112 additions & 16 deletions tests/units/test_attribute_access_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
from typing import Dict, List, Optional, Type, Union

import attrs
import pydantic.v1
import pytest
import sqlalchemy
import sqlmodel
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
relationship,
)

import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type
Expand Down Expand Up @@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels")
test_dataclass_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("test_dataclass.id")
)
test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")


class SQLAClass(SQLABase):
Expand Down Expand Up @@ -104,9 +116,64 @@ def first_label(self) -> Optional[SQLALabel]:
return self.labels[0] if self.labels else None


class SQLAClassDataclass(MappedAsDataclass, SQLABase):
"""Test sqlalchemy model."""

id: Mapped[int] = mapped_column(primary_key=True)
no_default: Mapped[int] = mapped_column(nullable=True)
count: Mapped[int] = mapped_column()
name: Mapped[str] = mapped_column()
int_list: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
)
str_list: Mapped[List[str]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
)
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
default_factory: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
default_factory=list,
)
__tablename__: str = "test_dataclass"

@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name

@hybrid_property
def str_or_int_property(self) -> Union[str, int]:
"""String or int property.
Returns:
Name attribute
"""
return self.name

@hybrid_property
def first_label(self) -> Optional[SQLALabel]:
"""First label property.
Returns:
First label
"""
return self.labels[0] if self.labels else None


class ModelClass(rx.Model):
"""Test reflex model."""

no_default: Optional[int] = sqlmodel.Field(nullable=True)
count: int = 0
name: str = "test"
int_list: List[int] = []
Expand All @@ -115,6 +182,7 @@ class ModelClass(rx.Model):
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
default_factory: List[int] = sqlmodel.Field(default_factory=list)

@property
def str_property(self) -> str:
Expand Down Expand Up @@ -147,6 +215,7 @@ def first_label(self) -> Optional[SQLALabel]:
class BaseClass(rx.Base):
"""Test rx.Base class."""

no_default: Optional[int] = pydantic.v1.Field(required=False)
count: int = 0
name: str = "test"
int_list: List[int] = []
Expand All @@ -155,6 +224,7 @@ class BaseClass(rx.Base):
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
default_factory: List[int] = pydantic.v1.Field(default_factory=list)

@property
def str_property(self) -> str:
Expand Down Expand Up @@ -236,6 +306,7 @@ class AttrClass:
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
default_factory: List[int] = attrs.field(factory=list)

@property
def str_property(self) -> str:
Expand Down Expand Up @@ -265,27 +336,17 @@ def first_label(self) -> Optional[SQLALabel]:
return self.labels[0] if self.labels else None


@pytest.fixture(
params=[
@pytest.mark.parametrize(
"cls",
[
SQLAClass,
SQLAClassDataclass,
BaseClass,
BareClass,
ModelClass,
AttrClass,
]
],
)
def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test.
Args:
request: pytest request object.
Returns:
Class to test.
"""
return request.param


@pytest.mark.parametrize(
"attr, expected",
[
Expand All @@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
expected: Expected type.
"""
assert get_attribute_access_type(cls, attr) == expected


@pytest.mark.parametrize(
"cls",
[
SQLAClassDataclass,
BaseClass,
ModelClass,
AttrClass,
],
)
def test_get_attribute_access_type_default_factory(cls: type) -> None:
"""Test get_attribute_access_type returns the correct type for default factory fields.
Args:
cls: Class to test.
"""
assert get_attribute_access_type(cls, "default_factory") == List[int]


@pytest.mark.parametrize(
"cls",
[
SQLAClassDataclass,
BaseClass,
ModelClass,
],
)
def test_get_attribute_access_type_no_default(cls: type) -> None:
"""Test get_attribute_access_type returns the correct type for fields with no default which are not required.
Args:
cls: Class to test.
"""
assert get_attribute_access_type(cls, "no_default") == Optional[int]

0 comments on commit e4b5755

Please sign in to comment.