diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index be8765d3..21c73255 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -178,6 +178,9 @@ def should_column_be_set(cls, column: Any) -> bool: if not cls.should_dataclass_init_field(column.name): return False + if column.computed and (cls.__session__ is not None or cls.__async_session__ is not None): + return False + return bool(cls.__set_foreign_keys__ or not column.foreign_keys) @classmethod diff --git a/tests/sqlalchemy_factory/models.py b/tests/sqlalchemy_factory/models.py index 4280b12e..ac33078b 100644 --- a/tests/sqlalchemy_factory/models.py +++ b/tests/sqlalchemy_factory/models.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Optional -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func, orm, text +from sqlalchemy import Boolean, Column, Computed, DateTime, ForeignKey, Integer, String, func, orm, text from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import relationship from sqlalchemy.orm.decl_api import DeclarativeMeta, registry @@ -146,3 +146,11 @@ class Employee(Base): name = Column(String) company_id = Column(Integer, ForeignKey("companies.id")) company = relationship(Company, back_populates="employees") + + +class Shape(Base): + __tablename__ = "shape" + + id = Column(Integer, primary_key=True) + side: Any = Column(Integer(), nullable=False, default=10) + area: Any = Column(Integer, Computed("side * side"), nullable=False) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index e3dcb1ff..a9b03db2 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -51,6 +51,7 @@ CollectionChildMixin, CollectionParentMixin, NonSQLAchemyClass, + Shape, _registry, ) from tests.sqlalchemy_factory.types import ListLike, SetLike @@ -133,6 +134,34 @@ class ModelFactory(SQLAlchemyFactory[Model]): ... assert instance.age * 3 == instance.triple_age +def test_computed_column_sync_persistence(engine: Engine) -> None: + Base.metadata.create_all(engine) + + class ShapeFactory(SQLAlchemyFactory[Shape]): + __model__ = Shape + __session__ = Session(engine) + + instance = ShapeFactory.create_sync() + assert instance.area == pow(instance.side, 2) + + +async def test_computed_column_async_persistence(engine: Engine, async_engine: AsyncEngine) -> None: + class ShapeFactory(SQLAlchemyFactory[Shape]): + __model__ = Shape + __async_session__ = AsyncSession(async_engine) + + instance = await ShapeFactory.create_async() + assert instance.area == pow(instance.side, 2) + + +def test_computed_column_no_persistence() -> None: + class ShapeFactory(SQLAlchemyFactory[Shape]): + __model__ = Shape + + fields = ShapeFactory.get_model_fields() + assert "area" in [field.name for field in fields] + + @pytest.mark.parametrize( "type_", tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),