Skip to content

Commit

Permalink
Merge pull request #2 from honglei/main
Browse files Browse the repository at this point in the history
get_column_from_field support functional sa_column
  • Loading branch information
mbsantiago authored Aug 25, 2023
2 parents 63e2692 + 4213c97 commit bcb6f32
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 18 deletions.
54 changes: 37 additions & 17 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@
else:
from typing_extensions import get_args, get_origin

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
from typing_extensions import Annotated, _AnnotatedAlias

_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
Expand Down Expand Up @@ -167,7 +164,7 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
Expand Down Expand Up @@ -440,17 +437,19 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union


def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
type_: Optional[type] | _AnnotatedAlias = field.annotation

# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union

if type_ is not None and is_optional_or_union(type_):
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
Expand All @@ -462,14 +461,20 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
if field.metadata:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
else:
return AutoString

if get_origin(type_) is Annotated:
org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)

# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
Expand Down Expand Up @@ -519,11 +524,18 @@ def is_optional_or_union(type_: Optional[type]) -> bool:


def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
Expand Down Expand Up @@ -587,6 +599,10 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object

def __init__(__pydantic_self__, **data: Any) -> None:
Expand Down Expand Up @@ -636,7 +652,10 @@ def model_validate(
# remove defaults so they don't get validated
data = {}
for key, value in validated:
field = cls.model_fields[key]
field = cls.model_fields.get(key)

if field is None:
continue

if (
hasattr(field, "default")
Expand All @@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True

return False
return False

Expand Down
1 change: 0 additions & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class AutoString(types.TypeDecorator): # type: ignore

impl = types.String
cache_ok = True
mysql_default_length = 255
Expand Down
78 changes: 78 additions & 0 deletions tests/test_class_hierarchy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import datetime
import sys

import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
Integer,
SQLModel,
String,
create_engine,
)
from typing_extensions import Annotated

MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_field_resuse():
class BasicFileLog(SQLModel):
resourceID: int = Field(
sa_column=lambda: Column(Integer, index=True), description=""" """
)
transportID: Annotated[int | None, Field(description=" for ")] = None
fileName: str = Field(
sa_column=lambda: Column(String, index=True), description=""" """
)
fileSize: int | None = Field(
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
)
beginTime: datetime.datetime | None = Field(
sa_column=lambda: Column(
DateTime(timezone=True),
index=True,
),
description="",
)

class SendFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
sendUser: str
dstUrl: MoveSharedUrl | None

class RecvFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
recvUser: str

sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url, echo=True)
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
SendFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
RecvFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
50 changes: 50 additions & 0 deletions tests/test_model_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine


def test_model_copy(clear_sqlmodel):
"""Test validation of implicit and explict None values.
# For consistency with pydantic, validators are not to be called on
# arguments that are not explicitly provided.
https://github.com/tiangolo/sqlmodel/issues/230
https://github.com/samuelcolvin/pydantic/issues/1223
"""

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None

hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(hero)
session.commit()
session.refresh(hero)

model_copy = hero.model_copy(update={"name": "Deadpond Copy"})

assert (
model_copy.name == "Deadpond Copy"
and model_copy.secret_name == "Dive Wilson"
and model_copy.age == 25
)

db_hero = session.get(Hero, hero.id)

db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})

assert (
db_copy.name == "Deadpond Copy"
and db_copy.secret_name == "Dive Wilson"
and db_copy.age == 25
)
22 changes: 22 additions & 0 deletions tests/test_nullable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Optional

import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Session, SQLModel, create_engine
from typing_extensions import Annotated

MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]


def test_nullable_fields(clear_sqlmodel, caplog):
Expand All @@ -13,6 +19,8 @@ class Hero(SQLModel, table=True):
)
required_value: str
optional_default_ellipsis: Optional[str] = Field(default=...)
optional_no_field: Optional[str]
optional_no_field_default: Optional[str] = Field(description="no default")
optional_default_none: Optional[str] = Field(default=None)
optional_non_nullable: Optional[str] = Field(
nullable=False,
Expand Down Expand Up @@ -49,6 +57,13 @@ class Hero(SQLModel, table=True):
str_default_str_nullable: str = Field(default="default", nullable=True)
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
base_url: AnyUrl
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
url: MoveSharedUrl
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
annotated_optional_url: Annotated[
Optional[MoveSharedUrl], Field(description="")
] = None

engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine)
Expand All @@ -59,6 +74,8 @@ class Hero(SQLModel, table=True):
assert "primary_key INTEGER NOT NULL," in create_table_log
assert "required_value VARCHAR NOT NULL," in create_table_log
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
assert "optional_no_field VARCHAR," in create_table_log
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
assert "optional_default_none VARCHAR," in create_table_log
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
assert "optional_nullable VARCHAR," in create_table_log
Expand All @@ -77,6 +94,11 @@ class Hero(SQLModel, table=True):
assert "str_default_str_nullable VARCHAR," in create_table_log
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
assert "base_url VARCHAR NOT NULL," in create_table_log
assert "optional_url VARCHAR(512), " in create_table_log
assert "url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_optional_url VARCHAR(512)," in create_table_log


# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420
Expand Down

0 comments on commit bcb6f32

Please sign in to comment.