Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add recursion guard #468

Merged
merged 13 commits into from
Jan 16, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory, T


class Base(DeclarativeBase):
...


class Author(Base):
__tablename__ = "authors"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

books: Mapped[List["Book"]] = relationship(
"Book",
uselist=True,
back_populates="author",
)


class Book(Base):
__tablename__ = "books"

id: Mapped[int] = mapped_column(primary_key=True)
author_id: Mapped[int] = mapped_column(ForeignKey(Author.id), nullable=False)
author: Mapped[Author] = relationship(
"Author",
uselist=False,
back_populates="books",
)


class BaseFactory(SQLAlchemyFactory[T]):
__is_base_factory__ = True
__set_relationships__ = True
__randomize_collection_length__ = True
__min_collection_length__ = 3


def test_custom_sqla_factory() -> None:
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session = Session(engine)

BaseFactory.__session__ = session # Or using a callable that returns a session

author = BaseFactory.create_factory(Author).create_sync()
assert author.id is not None
assert author.id == author.books[0].author_id

book = BaseFactory.create_factory(Book).create_sync()
assert book.id is not None
assert book.author.books == [book]

BaseFactory._base_factories.remove(BaseFactory) # noqa: SLF001
guacs marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ By default, this will add generated models to the session and then commit. This
Similarly for ``__async_session__`` and ``create_async``.


Adding global overrides
------------------------------

By combining the above and using other settings, a global base factory can be set up for other factories.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py
:caption: Using persistence
:language: python


API reference
------------------------------
Full API docs are available :class:`here <polyfactory.factories.sqlalchemy_factory.SQLAlchemyFactory>`.
2 changes: 1 addition & 1 deletion polyfactory/factories/attrs_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def resolve_types(cls, model: type[T], **kwargs: Any) -> None:
:param kwargs: Any parameters that need to be passed to `attrs.resolve_types`.
"""

attrs.resolve_types(model, **kwargs) # type: ignore[type-var]
attrs.resolve_types(model, **kwargs)
guacs marked this conversation as resolved.
Show resolved Hide resolved
88 changes: 60 additions & 28 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand Down Expand Up @@ -39,6 +40,7 @@
Mapping,
Sequence,
Type,
TypedDict,
TypeVar,
cast,
)
Expand All @@ -63,14 +65,7 @@
unwrap_optional,
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
is_literal,
is_optional,
is_safe_subclass,
is_union,
)
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
Expand All @@ -86,11 +81,7 @@
from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes
from polyfactory.value_generators.constrained_url import handle_constrained_url
from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid
from polyfactory.value_generators.primitives import (
create_random_boolean,
create_random_bytes,
create_random_string,
)
from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand All @@ -103,6 +94,17 @@
F = TypeVar("F", bound="BaseFactory[Any]")


class BuildContext(TypedDict):
seen_models: set[type]


def _get_build_context(build_context: BuildContext | None) -> BuildContext:
if build_context is None:
return {"seen_models": set()}

return copy.deepcopy(build_context)
guacs marked this conversation as resolved.
Show resolved Hide resolved


class BaseFactory(ABC, Generic[T]):
"""Base Factory class - this class holds the main logic of the library"""

Expand Down Expand Up @@ -271,7 +273,12 @@ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
)

@classmethod
def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
def _handle_factory_field(
cls,
field_value: Any,
build_context: BuildContext,
field_build_parameters: Any | None = None,
) -> Any:
"""Handle a value defined on the factory class itself.

:param field_value: A value defined as an attribute on the factory class.
Expand All @@ -281,12 +288,12 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return field_value.build(**field_build_parameters)
return field_value.build(build_context, **field_build_parameters)

if isinstance(field_build_parameters, Sequence):
return [field_value.build(**parameter) for parameter in field_build_parameters]
return [field_value.build(build_context, **parameter) for parameter in field_build_parameters]

return field_value.build()
return field_value.build(build_context)

if isinstance(field_value, Use):
return field_value.to_value()
Expand Down Expand Up @@ -615,6 +622,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

Expand All @@ -624,6 +632,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
:returns: An arbitrary value.

"""
build_context = _get_build_context(build_context)
if cls.is_ignored_type(field_meta.annotation):
return None

Expand All @@ -642,18 +651,27 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta)

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
return None
guacs marked this conversation as resolved.
Show resolved Hide resolved

return cls._get_or_create_factory(model=unwrapped_annotation).build(
build_context,
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
)

if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
factory = cls._get_or_create_factory(model=field_meta.type_args[0])
if isinstance(field_build_parameters, Sequence):
return [factory.build(**field_parameters) for field_parameters in field_build_parameters]
return [factory.build(build_context, **field_parameters) for field_parameters in field_build_parameters]

if field_meta.type_args[0] in build_context["seen_models"]:
guacs marked this conversation as resolved.
Show resolved Hide resolved
return []

if not cls.__randomize_collection_length__:
return [factory.build()]
return [factory.build(build_context)]

batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__)
return factory.batch(size=batch_size)
return factory.batch(size=batch_size, build_context=build_context)

if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
if cls.__randomize_collection_length__:
Expand Down Expand Up @@ -833,14 +851,17 @@ def _check_declared_fields_exist_in_model(cls) -> None:
raise ConfigurationException(error_message)

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
def process_kwargs(cls, build_context: BuildContext | None = None, **kwargs: Any) -> dict[str, Any]:
"""Process the given kwargs and generate values for the factory's model.

:param kwargs: Any build kwargs.

:returns: A dictionary of build results.

"""
build_context = _get_build_context(build_context)
build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand All @@ -863,25 +884,37 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
result[field_meta.name] = cls._handle_factory_field(
field_value=field_value,
field_build_parameters=field_build_parameters,
build_context=build_context,
)
continue

result[field_meta.name] = cls.get_field_value(field_meta, field_build_parameters=field_build_parameters)
result[field_meta.name] = cls.get_field_value(
field_meta,
field_build_parameters=field_build_parameters,
build_context=build_context,
)

for field_name, post_generator in generate_post.items():
result[field_name] = post_generator.to_value(field_name, result)

return result

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
def process_kwargs_coverage(
cls,
build_context: BuildContext | None = None,
**kwargs: Any,
) -> abc.Iterable[dict[str, Any]]:
"""Process the given kwargs and generate values for the factory's model.

:param kwargs: Any build kwargs.

:returns: A dictionary of build results.

"""
build_context = _get_build_context(build_context)
build_context["seen_models"].add(cls.__model__)

result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

Expand Down Expand Up @@ -918,19 +951,18 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
yield resolved

@classmethod
def build(cls, **kwargs: Any) -> T:
def build(cls, build_context: BuildContext | None = None, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__

:param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.

:returns: An instance of type T.

"""

return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
return cast("T", cls.__model__(**cls.process_kwargs(build_context, **kwargs)))

@classmethod
def batch(cls, size: int, **kwargs: Any) -> list[T]:
def batch(cls, size: int, build_context: BuildContext | None = None, **kwargs: Any) -> list[T]:
"""Build a batch of size n of the factory's Meta.model.

:param size: Size of the batch.
Expand All @@ -939,7 +971,7 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
:returns: A list of instances of type T.

"""
return [cls.build(**kwargs) for _ in range(size)]
return [cls.build(build_context, **kwargs) for _ in range(size)]

@classmethod
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
Expand Down
14 changes: 12 additions & 2 deletions polyfactory/factories/beanie_odm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
if TYPE_CHECKING:
from typing_extensions import TypeGuard

from polyfactory.factories.base import BuildContext
from polyfactory.field_meta import FieldMeta

try:
Expand Down Expand Up @@ -55,7 +56,12 @@ def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":
return is_safe_subclass(value, Document)

@classmethod
def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any | None = None) -> Any:
def get_field_value(
cls,
field_meta: "FieldMeta",
field_build_parameters: Any | None = None,
build_context: BuildContext | None = None,
) -> Any:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

:param field_meta: FieldMeta instance.
Expand All @@ -74,4 +80,8 @@ def get_field_value(cls, field_meta: "FieldMeta", field_build_parameters: Any |
field_meta.annotation = link_class
field_meta.annotation = link_class

return super().get_field_value(field_meta=field_meta, field_build_parameters=field_build_parameters)
return super().get_field_value(
field_meta=field_meta,
field_build_parameters=field_build_parameters,
build_context=build_context,
)
11 changes: 8 additions & 3 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.field_meta import Constraints, FieldMeta, Null
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
Expand Down Expand Up @@ -368,7 +368,12 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) ->
return super().get_constrained_field_value(annotation, field_meta)

@classmethod
def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T:
def build(
cls,
build_context: BuildContext | None = None,
factory_use_construct: bool = False,
**kwargs: Any,
) -> T:
guacs marked this conversation as resolved.
Show resolved Hide resolved
"""Build an instance of the factory's __model__

:param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
Expand All @@ -378,7 +383,7 @@ def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T:
:returns: An instance of type T.

"""
processed_kwargs = cls.process_kwargs(**kwargs)
processed_kwargs = cls.process_kwargs(build_context, **kwargs)

if factory_use_construct:
return (
Expand Down
Loading
Loading