Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
from typing_extensions import (
Literal,
TypeAlias,
_AnnotatedAlias,
deprecated,
get_origin,
)

from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
Expand Down Expand Up @@ -653,13 +659,24 @@ def get_sqlalchemy_type(field: Any) -> Any:
return sa_type

type_ = get_sa_type_from_field(field)
metadata = get_field_metadata(field)
if isinstance(type_, _AnnotatedAlias):
class_to_compare = type_.__origin__
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to use_type?

if len(type_.__metadata__) == 1:
metadata = get_field_metadata(type_.__metadata__[0])
else:
# not sure if this is the right behavior
raise ValueError(
f"AnnotatedAlias with multiple metadata is not supported: {type_}"
)
else:
class_to_compare = type_
metadata = get_field_metadata(field)

# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if issubclass(type_, Enum):
if issubclass(class_to_compare, Enum):
return sa_Enum(type_)
if issubclass(
type_,
class_to_compare,
(
str,
ipaddress.IPv4Address,
Expand All @@ -674,28 +691,28 @@ def get_sqlalchemy_type(field: Any) -> Any:
if max_length:
return AutoString(length=max_length)
return AutoString
if issubclass(type_, float):
if issubclass(class_to_compare, float):
return Float
if issubclass(type_, bool):
if issubclass(class_to_compare, bool):
return Boolean
if issubclass(type_, int):
if issubclass(class_to_compare, int):
return Integer
if issubclass(type_, datetime):
if issubclass(class_to_compare, datetime):
return DateTime
if issubclass(type_, date):
if issubclass(class_to_compare, date):
return Date
if issubclass(type_, timedelta):
if issubclass(class_to_compare, timedelta):
return Interval
if issubclass(type_, time):
if issubclass(class_to_compare, time):
return Time
if issubclass(type_, bytes):
if issubclass(class_to_compare, bytes):
return LargeBinary
if issubclass(type_, Decimal):
if issubclass(class_to_compare, Decimal):
return Numeric(
precision=getattr(metadata, "max_digits", None),
scale=getattr(metadata, "decimal_places", None),
)
if issubclass(type_, uuid.UUID):
if issubclass(class_to_compare, uuid.UUID):
return Uuid
raise ValueError(f"{type_} has no matching SQLAlchemy type")

Expand Down
25 changes: 24 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
from decimal import Decimal
from typing import Annotated, List, Optional

import pytest
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -125,3 +126,25 @@ class Hero(SQLModel, table=True):
# The next statement should not raise an AttributeError
assert hero_rusty_man.team
assert hero_rusty_man.team.name == "Preventers"


def test_optional_annotated_decimal():
class Model(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
dec: Annotated[Decimal, Field(max_digits=4, decimal_places=2)] | None = None

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(model := Model(dec=Decimal("3.14")))
session.commit()
session.refresh(model)
assert model.dec == Decimal("3.14")

with Session(engine) as session:
session.add(model := Model(dec=Decimal("3.142")))
session.commit()
session.refresh(model)
assert model.dec == Decimal("3.14")
Loading