Skip to content

Commit

Permalink
Merge branch 'master' of github.com:kingbuzzman/factory_boy
Browse files Browse the repository at this point in the history
  • Loading branch information
kingbuzzman committed May 19, 2022
2 parents 9dba520 + d111d1a commit 2f8ea92
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 54 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ ChangeLog

- :issue:`366`: Add :class:`factory.django.Password` to generate Django :class:`~django.contrib.auth.models.User`
passwords.
- :issue:`304`: Add :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session_factory` to dynamically
create sessions for use by the :class:`~factory.alchemy.SQLAlchemyModelFactory`.
- Add support for Django 3.2
- Add support for Django 4.0
- Add support for Python 3.10
Expand Down
19 changes: 19 additions & 0 deletions docs/orms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,25 @@ To work, this class needs an `SQLAlchemy`_ session object affected to the :attr:
SQLAlchemy session to use to communicate with the database when creating
an object through this :class:`SQLAlchemyModelFactory`.

.. attribute:: sqlalchemy_session_factory

.. versionadded:: 3.3.0

:class:`~collections.abc.Callable` returning a :class:`~sqlalchemy.orm.Session` instance to use to communicate
with the database. You can either provide the session through this attribute, or through
:attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session`, but not both at the same time.

.. code-block:: python

    from . import common

    class UserFactory(factory.alchemy.SQLAlchemyModelFactory):
        class Meta:
            model = User
            sqlalchemy_session_factory = lambda: common.Session()

        username = 'john'

.. attribute:: sqlalchemy_session_persistence

Control the action taken by ``sqlalchemy_session`` at the end of a create call.
Expand Down
6 changes: 4 additions & 2 deletions docs/recipes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ simply use a :class:`factory.Iterator` on the chosen queryset:
language = factory.Iterator(models.Language.objects.all())
Here, ``models.Language.objects.all()`` won't be evaluated until the
first call to ``UserFactory``; thus avoiding DB queries at import time.
Here, ``models.Language.objects.all()`` is a
:class:`~django.db.models.query.QuerySet` and will only hit the database when
``factory_boy`` starts iterating on it, i.e. on the first call to
``UserFactory``, thus avoiding DB queries at import time.


Reverse dependencies (reverse ForeignKey)
Expand Down
12 changes: 12 additions & 0 deletions factory/alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,18 @@ def _check_sqlalchemy_session_persistence(self, meta, value):
(meta, VALID_SESSION_PERSISTENCE_TYPES, value)
)

@staticmethod
def _check_has_sqlalchemy_session_set(meta, value):
    """Ensure a session and a session factory are not both configured.

    Args:
        meta: options object whose ``sqlalchemy_session`` attribute is inspected.
        value: the ``sqlalchemy_session_factory`` value being validated.

    Raises:
        RuntimeError: if both ``value`` and ``meta.sqlalchemy_session`` are truthy.
    """
    if not value:
        # No factory supplied: nothing can conflict with the session.
        return
    if meta.sqlalchemy_session:
        raise RuntimeError("Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both")

def _build_default_options(self):
return super()._build_default_options() + [
base.OptionDefault('sqlalchemy_get_or_create', (), inherit=True),
base.OptionDefault('sqlalchemy_session', None, inherit=True),
base.OptionDefault(
'sqlalchemy_session_factory', None, inherit=True, checker=self._check_has_sqlalchemy_session_set
),
base.OptionDefault(
'sqlalchemy_session_persistence',
None,
Expand Down Expand Up @@ -90,6 +98,10 @@ def _get_or_create(cls, model_class, session, args, kwargs):
@classmethod
def _create(cls, model_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
session_factory = cls._meta.sqlalchemy_session_factory
if session_factory:
cls._meta.sqlalchemy_session = session_factory()

session = cls._meta.sqlalchemy_session

if session is None:
Expand Down
86 changes: 34 additions & 52 deletions factory/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS

DJANGO_22 = Version('2.2') <= Version(django_version) < Version('3.0')
DJANGO_22 = Version(django_version) < Version('3.0')

_LAZY_LOADS = {}

Expand Down Expand Up @@ -205,9 +205,18 @@ def create_batch(cls, size, **kwargs):

@classmethod
def _refresh_database_pks(cls, model_cls, objs):
"""
Before Django 3.0, there is an issue with bulk inserts.
If an unsaved model instance is referenced by another unsaved
model instance, then saving the first instance does not update
the pk/id stored on the instance that references it.
"""
if not DJANGO_22:
return
fields = [f for f in model_cls._meta.get_fields() if isinstance(f, models.fields.related.ForeignObject)]
fields = [f for f in model_cls._meta.get_fields()
if isinstance(f, models.fields.related.ForeignObject)]
if not fields:
return
for obj in objs:
Expand All @@ -217,17 +226,13 @@ def _refresh_database_pks(cls, model_cls, objs):
@classmethod
def _bulk_create(cls, size, **kwargs):
models_to_create = cls.build_batch(size, **kwargs)
collector = Collector(cls._meta.database)
collector = DependencyInsertOrderCollector()
collector.collect(cls, models_to_create)
collector.sort()
for model_cls, objs in collector.data.items():
manager = cls._get_manager(model_cls)
for instance in objs:
models.signals.pre_save.send(model_cls, instance=instance, created=False)
cls._refresh_database_pks(model_cls, objs)
manager.bulk_create(objs)
for instance in objs:
models.signals.post_save.send(model_cls, instance=instance, created=True)
return models_to_create

@classmethod
Expand Down Expand Up @@ -334,29 +339,20 @@ def _make_data(self, params):
return thumb_io.getvalue()


class Collector:
def __init__(self, using):
self.using = using
class DependencyInsertOrderCollector:
def __init__(self):
# Initially, {model: {instances}}, later values become lists.
self.data = defaultdict(list)
# {model: {(field, value): {instances}}}
self.field_updates = defaultdict(functools.partial(defaultdict, set))
# {model: {field: {instances}}}
self.restricted_objects = defaultdict(functools.partial(defaultdict, set))
# fast_deletes is a list of queryset-likes that can be deleted without
# fetching the objects into memory.
self.fast_deletes = []

# Tracks deletion-order dependency for databases without transactions
# or ability to defer constraint checks. Only concrete model classes
# should be included, as the dependencies exist only between actual
# database tables; proxy models are represented here by their concrete
# parent.
self.dependencies = defaultdict(set) # {model: {models}}

def add(self, objs, source=None, nullable=False, reverse_dependency=False):
def add(self, objs, source=None, nullable=False):
"""
Add 'objs' to the collection of objects to be deleted. If the call is
Add 'objs' to the collection of objects to be inserted in order. If the call is
the result of a cascade, 'source' should be the model that caused it,
and 'nullable' should be set to True if the relation can be null.
Return a list of all objects that were not already collected.
Expand All @@ -372,21 +368,15 @@ def add(self, objs, source=None, nullable=False, reverse_dependency=False):
continue
if id(obj) not in lookup:
new_objs.append(obj)
# import ipdb; ipdb.sset_trace()
instances.extend(new_objs)
# Nullable relationships can be ignored -- they are nulled out before
# deleting, and therefore do not affect the order in which objects have
# to be deleted.
if source is not None and not nullable:
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
# if not nullable:
# import ipdb; ipdb.sset_trace()
# self.add_dependency(source, model, reverse_dependency=reverse_dependency)
self.add_dependency(source, model)
return new_objs

def add_dependency(self, model, dependency, reverse_dependency=False):
if reverse_dependency:
model, dependency = dependency, model
def add_dependency(self, model, dependency):
self.dependencies[model._meta.concrete_model].add(
dependency._meta.concrete_model
)
Expand All @@ -398,11 +388,6 @@ def collect(
objs,
source=None,
nullable=False,
collect_related=True,
source_attr=None,
reverse_dependency=False,
keep_parents=False,
fail_on_restricted=True,
):
"""
Add 'objs' to the collection of objects to be deleted as well as all
Expand All @@ -412,10 +397,6 @@ def collect(
If the call is the result of a cascade, 'source' should be the model
that caused it and 'nullable' should be set to True, if the relation
can be null.
If 'reverse_dependency' is True, 'source' will be deleted before the
current model, rather than after. (Needed for cascading to parent
models, the one case in which the cascade follows the forwards
direction of an FK rather than the reverse direction.)
If 'keep_parents' is True, data of parent model's will be not deleted.
If 'fail_on_restricted' is False, error won't be raised even if it's
prohibited to delete such objects due to RESTRICT, that defers
Expand All @@ -424,47 +405,47 @@ def collect(
can be deleted.
"""
new_objs = self.add(
objs, source, nullable, reverse_dependency=reverse_dependency
objs, source, nullable
)
if not new_objs:
return

# import ipdb; ipdb.sset_trace()
model = new_objs[0].__class__

def get_candidate_relations(opts):
# The candidate relations are the ones that come from N-1 and 1-1 relations.
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
return (
f
for f in opts.get_fields(include_hidden=True)
if isinstance(f, models.ForeignKey)
)
# The candidate relations are the ones that come from N-1 and 1-1 relations.
candidate_relations = (
f for f in model._meta.get_fields(include_hidden=True)
if isinstance(f, models.ForeignKey)
)

collected_objs = []
for field in get_candidate_relations(model._meta):
for field in candidate_relations:
for obj in new_objs:
val = getattr(obj, field.name)
if isinstance(val, models.Model):
collected_objs.append(val)

for name, _ in factory_cls._meta.post_declarations.as_dict().items():

# Iterate over the post-declaration names directly. The previous form,
# ``for name, in ...keys()``, tried to unpack every string key as a 1-tuple
# and raised ValueError for any name longer than one character.
for name in factory_cls._meta.post_declarations.as_dict():
    for obj in new_objs:
        # Post-declarations may not exist as attributes on every instance.
        val = getattr(obj, name, None)
        if isinstance(val, models.Model):
            collected_objs.append(val)

if collected_objs:
new_objs = self.collect(
factory_cls=factory_cls, objs=collected_objs, source=model, reverse_dependency=False
factory_cls=factory_cls, objs=collected_objs, source=model
)

def sort(self):
"""
Sort the model instances from the fewest dependencies to the most dependencies.
Models with no dependencies are inserted first, followed by the models
that depend on them.
"""
sorted_models = []
concrete_models = set()
models = list(self.data)
# import ipdb; ipdb.sset_trace()
while len(sorted_models) < len(models):
found = False
for model in models:
Expand All @@ -476,6 +457,7 @@ def sort(self):
concrete_models.add(model._meta.concrete_model)
found = True
if not found:
logger.debug('dependency order could not be determined')
return
self.data = {model: self.data[model] for model in sorted_models}

Expand Down
28 changes: 28 additions & 0 deletions tests/test_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,34 @@ def test_build_does_not_raises_exception_when_no_session_was_set(self):
self.assertEqual(inst1.id, 1)


class SQLAlchemySessionFactoryTestCase(unittest.TestCase):
    """Tests for the ``sqlalchemy_session_factory`` Meta option."""

    def test_create_get_session_from_sqlalchemy_session_factory(self):
        """``create()`` takes its session from the configured factory callable."""
        class SessionGetterFactory(SQLAlchemyModelFactory):
            class Meta:
                model = models.StandardModel
                sqlalchemy_session = None
                sqlalchemy_session_factory = lambda: models.session

            id = factory.Sequence(lambda n: n)

        SessionGetterFactory.create()
        self.assertEqual(SessionGetterFactory._meta.sqlalchemy_session, models.session)
        # A second create() must still work once a session is stored on _meta.
        SessionGetterFactory.create()

    def test_create_raise_exception_sqlalchemy_session_and_factory_both_set(self):
        """Declaring both a session and a session factory raises RuntimeError."""
        message = "^Provide either a sqlalchemy_session or a sqlalchemy_session_factory, not both$"
        with self.assertRaisesRegex(RuntimeError, message):
            class SessionAndGetterFactory(SQLAlchemyModelFactory):
                class Meta:
                    model = models.StandardModel
                    sqlalchemy_session = models.session
                    sqlalchemy_session_factory = lambda: models.session

                id = factory.Sequence(lambda n: n)


class NameConflictTests(unittest.TestCase):
"""Regression test for `TypeError: _save() got multiple values for argument 'session'`
Expand Down
10 changes: 10 additions & 0 deletions tests/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ class Meta:
level_2 = factory.SubFactory(Level2Factory)


class DependencyInsertOrderCollector(django_test.TestCase):
    """Tests for ``factory.django.DependencyInsertOrderCollector``."""

    def test_empty(self):
        """Collecting an empty list and sorting leaves ``data`` equal to ``{}``."""
        order_collector = factory.django.DependencyInsertOrderCollector()
        order_collector.collect(Level2Factory, [])
        order_collector.sort()

        self.assertEqual(order_collector.data, {})


@unittest.skipIf(SKIP_BULK_INSERT, "bulk insert not supported by current db.")
class DjangoBulkInsert(django_test.TestCase):

Expand Down

0 comments on commit 2f8ea92

Please sign in to comment.