diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..ae78cc9f2b 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -42,6 +42,7 @@ from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( Mapped, + MappedColumn, RelationshipProperty, declared_attr, registry, @@ -702,13 +703,13 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, (Column, MappedColumn)): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index e2ccc6d7ef..8e4f0a4c33 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -5,7 +5,7 @@ from sqlmodel import Field, SQLModel -def test_sa_column_takes_precedence() -> None: +def test_sa_column_takes_precedence(clear_sqlmodel) -> None: class Item(SQLModel, table=True): id: Optional[int] = Field( default=None, diff --git a/tests/test_field_sa_column_mapped_column.py b/tests/test_field_sa_column_mapped_column.py new file mode 100644 index 0000000000..ad4f22e053 --- /dev/null +++ b/tests/test_field_sa_column_mapped_column.py @@ -0,0 +1,122 @@ +from typing import Optional + +import pytest +from sqlalchemy import Integer, String +from sqlalchemy.orm import mapped_column +from sqlmodel import Field, SQLModel + + +def test_sa_column_takes_precedence(clear_sqlmodel) -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(String, primary_key=True, nullable=False), + ) + + # It would have been nullable with no sa_column + assert Item.id.nullable is False # type: ignore + assert isinstance(Item.id.type, String) # type: ignore + + +def test_sa_column_no_sa_args() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_args=[Integer], + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_sa_kargs() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_type() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_type=Integer, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_primary_key() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + primary_key=True, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_nullable() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + nullable=True, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_foreign_key() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + foreign_key="team.id", + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_unique() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + unique=True, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_index() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + index=True, + sa_column=mapped_column(Integer, primary_key=True), + ) + + +def test_sa_column_no_ondelete() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(Integer, primary_key=True), + ondelete="CASCADE", + )