diff --git a/docs/examples/library_factories/sqlalchemy_factory/conftest.py b/docs/examples/library_factories/sqlalchemy_factory/conftest.py new file mode 100644 index 00000000..2227e64c --- /dev/null +++ b/docs/examples/library_factories/sqlalchemy_factory/conftest.py @@ -0,0 +1,11 @@ +from collections.abc import Iterable + +import pytest + +from docs.examples.library_factories.sqlalchemy_factory.test_example_4 import BaseFactory + + +@pytest.fixture(scope="module") +def _remove_default_factories() -> Iterable[None]: + yield + BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001 diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py new file mode 100644 index 00000000..0d616a38 --- /dev/null +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_4.py @@ -0,0 +1,58 @@ +from typing import List + +from sqlalchemy import ForeignKey, create_engine +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship + +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory, T + + +class Base(DeclarativeBase): + ... + + +class Author(Base): + __tablename__ = "authors" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + books: Mapped[List["Book"]] = relationship( + "Book", + uselist=True, + back_populates="author", + ) + + +class Book(Base): + __tablename__ = "books" + + id: Mapped[int] = mapped_column(primary_key=True) + author_id: Mapped[int] = mapped_column(ForeignKey(Author.id), nullable=False) + author: Mapped[Author] = relationship( + "Author", + uselist=False, + back_populates="books", + ) + + +class BaseFactory(SQLAlchemyFactory[T]): + __is_base_factory__ = True + __set_relationships__ = True + __randomize_collection_length__ = True + __min_collection_length__ = 3 + + +def test_custom_sqla_factory() -> None: + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + session = Session(engine) + + BaseFactory.__session__ = session # Or using a callable that returns a session + + author = BaseFactory.create_factory(Author).create_sync() + assert author.id is not None + assert author.id == author.books[0].author_id + + book = BaseFactory.create_factory(Book).create_sync() + assert book.id is not None + assert book.author.books == [book] diff --git a/docs/usage/library_factories/sqlalchemy_factory.rst b/docs/usage/library_factories/sqlalchemy_factory.rst index ef864dc7..ee72c9e9 100644 --- a/docs/usage/library_factories/sqlalchemy_factory.rst +++ b/docs/usage/library_factories/sqlalchemy_factory.rst @@ -37,6 +37,16 @@ By default, this will add generated models to the session and then commit. This Similarly for ``__async_session__`` and ``create_async``. +Adding global overrides +------------------------------ + +By combining the above and using other settings, a global base factory can be set up for other factories. + +.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py + :caption: Using persistence + :language: python + + API reference ------------------------------ Full API docs are available :class:`here `. diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 71b0407d..9ab29c0e 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress @@ -41,6 +42,7 @@ Mapping, Sequence, Type, + TypedDict, TypeVar, cast, ) @@ -65,14 +67,7 @@ unwrap_optional, ) from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage -from polyfactory.utils.predicates import ( - get_type_origin, - is_any, - is_literal, - is_optional, - is_safe_subclass, - is_union, -) +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, @@ -88,11 +83,7 @@ from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes from polyfactory.value_generators.constrained_url import handle_constrained_url from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid -from polyfactory.value_generators.primitives import ( - create_random_boolean, - create_random_bytes, - create_random_string, -) +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string if TYPE_CHECKING: from typing_extensions import TypeGuard @@ -105,6 +96,17 @@ F = TypeVar("F", bound="BaseFactory[Any]") +class BuildContext(TypedDict): + seen_models: set[type] + + +def _get_build_context(build_context: BuildContext | None) -> BuildContext: + if build_context is None: + return {"seen_models": set()} + + return copy.deepcopy(build_context) + + class BaseFactory(ABC, Generic[T]): """Base Factory class - this class holds the main logic of the library""" @@ -277,7 +279,12 @@ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: ) @classmethod - def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + def _handle_factory_field( + cls, + field_value: Any, + build_context: BuildContext, + field_build_parameters: Any | None = None, + ) -> Any: """Handle a value defined on the factory class itself. :param field_value: A value defined as an attribute on the factory class. @@ -287,12 +294,14 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return field_value.build(**field_build_parameters) + return field_value.build(_build_context=build_context, **field_build_parameters) if isinstance(field_build_parameters, Sequence): - return [field_value.build(**parameter) for parameter in field_build_parameters] + return [ + field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters + ] - return field_value.build() + return field_value.build(_build_context=build_context) if isinstance(field_value, Use): return field_value.to_value() @@ -303,7 +312,12 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N return field_value() if callable(field_value) else field_value @classmethod - def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + def _handle_factory_field_coverage( + cls, + field_value: Any, + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: """Handle a value defined on the factory class itself. :param field_value: A value defined as an attribute on the factory class. @@ -313,10 +327,13 @@ def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters """ if is_safe_subclass(field_value, BaseFactory): if isinstance(field_build_parameters, Mapping): - return CoverageContainer(field_value.coverage(**field_build_parameters)) + return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) if isinstance(field_build_parameters, Sequence): - return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters] + return [ + CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) + for parameter in field_build_parameters + ] return CoverageContainer(field_value.coverage()) @@ -621,15 +638,18 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, ) -> Any: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. :returns: An arbitrary value. """ + build_context = _get_build_context(build_context) if cls.is_ignored_type(field_meta.annotation): return None @@ -648,20 +668,32 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) if BaseFactory.is_factory_type(annotation=unwrapped_annotation): + if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: + return None if is_optional(field_meta.annotation) else Null + return cls._get_or_create_factory(model=unwrapped_annotation).build( + _build_context=build_context, **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), ) if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): factory = cls._get_or_create_factory(model=field_meta.type_args[0]) if isinstance(field_build_parameters, Sequence): - return [factory.build(**field_parameters) for field_parameters in field_build_parameters] + return [ + factory.build(_build_context=build_context, **field_parameters) + for field_parameters in field_build_parameters + ] + + if field_meta.type_args[0] in build_context["seen_models"]: + return [] + if not cls.__randomize_collection_length__: - return [factory.build()] + return [factory.build(_build_context=build_context)] + batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) - return factory.batch(size=batch_size) + return factory.batch(size=batch_size, _build_context=build_context) - if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): if cls.__randomize_collection_length__: collection_type = get_collection_type(unwrapped_annotation) if collection_type != dict: @@ -682,8 +714,9 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return handle_collection_type(field_meta, origin, cls) - if is_union(field_meta.annotation) and field_meta.children: - return cls.get_field_value(cls.__random__.choice(field_meta.children)) + if is_union(unwrapped_annotation) and field_meta.children: + children = [child for child in field_meta.children if child.annotation not in build_context["seen_models"]] + return cls.get_field_value(cls.__random__.choice(children)) if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): return create_random_string(cls.__random__, min_length=1, max_length=10) @@ -707,11 +740,13 @@ def get_field_value_coverage( # noqa: C901 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, ) -> Iterable[Any]: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. :returns: An iterable of values. @@ -739,6 +774,7 @@ def get_field_value_coverage( # noqa: C901 elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): yield CoverageContainer( cls._get_or_create_factory(model=unwrapped_annotation).coverage( + _build_context=build_context, **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), ), ) @@ -861,6 +897,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: :returns: A dictionary of build results. """ + _build_context = _get_build_context(kwargs.pop("_build_context", None)) + _build_context["seen_models"].add(cls.__model__) + result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -883,10 +922,19 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: result[field_meta.name] = cls._handle_factory_field( field_value=field_value, field_build_parameters=field_build_parameters, + build_context=_build_context, ) continue - result[field_meta.name] = cls.get_field_value(field_meta, field_build_parameters=field_build_parameters) + field_result = cls.get_field_value( + field_meta, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ) + if field_result is Null: + continue + + result[field_meta.name] = field_result for field_name, post_generator in generate_post.items(): result[field_name] = post_generator.to_value(field_name, result) @@ -898,10 +946,14 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: """Process the given kwargs and generate values for the factory's model. :param kwargs: Any build kwargs. + :param build_context: BuildContext data for current build. :returns: A dictionary of build results. """ + _build_context = _get_build_context(kwargs.pop("_build_context", None)) + _build_context["seen_models"].add(cls.__model__) + result: dict[str, Any] = {**kwargs} generate_post: dict[str, PostGenerated] = {} @@ -925,11 +977,16 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: result[field_meta.name] = cls._handle_factory_field_coverage( field_value=field_value, field_build_parameters=field_build_parameters, + build_context=_build_context, ) continue result[field_meta.name] = CoverageContainer( - cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters), + cls.get_field_value_coverage( + field_meta, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ), ) for resolved in resolve_kwargs_coverage(result): @@ -946,7 +1003,6 @@ def build(cls, **kwargs: Any) -> T: :returns: An instance of type T. """ - return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) @classmethod diff --git a/polyfactory/factories/beanie_odm_factory.py b/polyfactory/factories/beanie_odm_factory.py index 8f98a5cf..ddd31697 100644 --- a/polyfactory/factories/beanie_odm_factory.py +++ b/polyfactory/factories/beanie_odm_factory.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from typing_extensions import TypeGuard + from polyfactory.factories.base import BuildContext from polyfactory.field_meta import FieldMeta try: @@ -55,7 +56,12 @@ def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": return is_safe_subclass(value, Document) @classmethod - def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | None = None) -> Any: + def get_field_value( + cls, + field_meta: "FieldMeta", + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. @@ -74,4 +80,8 @@ def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | field_meta.annotation = link_class field_meta.annotation = link_class - return super().get_field_value(field_meta=field_meta, field_build_parameters=field_build_parameters) + return super().get_field_value( + field_meta=field_meta, + field_build_parameters=field_build_parameters, + build_context=build_context, + ) diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 64102e2c..16218dc4 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -368,7 +368,11 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> return super().get_constrained_field_value(annotation, field_meta) @classmethod - def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T: + def build( + cls, + factory_use_construct: bool = False, + **kwargs: Any, + ) -> T: """Build an instance of the factory's __model__ :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index c28c0626..a6432b58 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -11,13 +11,7 @@ from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING -from polyfactory.utils.predicates import ( - is_annotated, - is_new_type, - is_optional, - is_safe_subclass, - is_union, -) +from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_safe_subclass, is_union if TYPE_CHECKING: from random import Random @@ -71,15 +65,14 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any: :returns: The unwrapped annotation. """ - while is_optional(annotation) or is_union(annotation) or is_new_type(annotation) or is_annotated(annotation): + while is_optional(annotation) or is_new_type(annotation) or is_annotated(annotation): if is_new_type(annotation): annotation = unwrap_new_type(annotation) elif is_optional(annotation): annotation = unwrap_optional(annotation) elif is_annotated(annotation): annotation = unwrap_annotated(annotation, random=random)[0] - else: - annotation = unwrap_union(annotation, random=random) + return annotation diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 2bb7873e..8cdcf667 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -10,6 +10,7 @@ from sqlalchemy.orm.decl_api import DeclarativeMeta, registry from polyfactory.exceptions import ConfigurationException +from polyfactory.factories.base import BaseFactory from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -198,6 +199,26 @@ class AuthorFactory(SQLAlchemyFactory[Author]): assert isinstance(result.books[0], Book) +def test_sqla_factory_create() -> None: + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + + class OverridenSQLAlchemyFactory(SQLAlchemyFactory): + __is_base_factory__ = True + __session__ = Session(engine) + __set_relationships__ = True + + author: Author = OverridenSQLAlchemyFactory.create_factory(Author).create_sync() + assert isinstance(author.books[0], Book) + assert author.books[0].author is author + + book = OverridenSQLAlchemyFactory.create_factory(Book).create_sync() + assert book.author is not None + assert book.author.books == [book] + + BaseFactory._base_factories.remove(OverridenSQLAlchemyFactory) + + async def test_invalid_peristence_config_raises() -> None: class AuthorFactory(SQLAlchemyFactory[Author]): __model__ = Author diff --git a/tests/test_dicts.py b/tests/test_dicts.py index 0eae3da0..d9ad95b2 100644 --- a/tests/test_dicts.py +++ b/tests/test_dicts.py @@ -36,10 +36,10 @@ class MyClass(BaseModel): class MyClassFactory(ModelFactory[MyClass]): __model__ = MyClass - MyClassFactory.seed_random(100) + MyClassFactory.seed_random(10) test_obj_1 = MyClassFactory.build() test_obj_2 = MyClassFactory.build() - assert isinstance(next(iter(test_obj_1.val.values())), str) - assert isinstance(next(iter(test_obj_2.val.values())), int) + assert isinstance(next(iter(test_obj_1.val.values())), int) + assert isinstance(next(iter(test_obj_2.val.values())), str) diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py new file mode 100644 index 00000000..9f8c4ad1 --- /dev/null +++ b/tests/test_recursive_models.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +import pytest +from pydantic import BaseModel, Field + +from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.pydantic_factory import ModelFactory + + +class _Sentinel: + ... + + +@dataclass +class Node: + value: int + union_child: Union[Node, int] # noqa: UP007 + list_child: List[Node] # noqa: UP006 + optional_child: Optional[Node] # noqa: UP007 + child: Node = field(default=_Sentinel) # type: ignore[assignment] + + def __post_init__(self) -> None: + # Emulate recursive models set by external init, e.g. ORM relationships + if self.child is _Sentinel: # type: ignore[comparison-overlap] + self.child = self + + +def test_recusive_model() -> None: + factory = DataclassFactory.create_factory(Node) + + result = factory.build() + assert result.child is result, "Default is not used" + assert isinstance(result.union_child, int) + assert result.optional_child is None + assert result.list_child == [] + + assert factory.build(child={"child": None}).child.child is None + + +class PydanticNode(BaseModel): + value: int + union_child: Union[PydanticNode, int] # noqa: UP007 + list_child: List[PydanticNode] # noqa: UP006 + optional_child: Union[PydanticNode, None] # noqa: UP007 + child: PydanticNode = Field(default=_Sentinel) # type: ignore[assignment] + + +@pytest.mark.parametrize("factory_use_construct", (True, False)) +def test_recursive_pydantic_models(factory_use_construct: bool) -> None: + factory = ModelFactory.create_factory(PydanticNode) + + result = factory.build(factory_use_construct) + assert result.child is _Sentinel, "Default is not used" + assert isinstance(result.union_child, int) + assert result.optional_child is None + assert result.list_child == [] + + +@dataclass +class Author: + name: str + books: List[Book] # noqa: UP006 + + +_DEFAULT_AUTHOR = Author(name="default", books=[]) + + +@dataclass +class Book: + name: str + author: Author = field(default_factory=lambda: _DEFAULT_AUTHOR) + + +def test_recusive_list_model() -> None: + factory = DataclassFactory.create_factory(Author) + assert factory.build().books[0].author is _DEFAULT_AUTHOR + assert factory.build(books=[]).books == [] + + book_factory = DataclassFactory.create_factory(Book) + assert book_factory.build().author.books == [] + assert book_factory.build(author=None).author is None