diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ee3495a..54f3443d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - `_bulk_create` flag is not populating related objects as well [PR #206](https://github.com/model-bakers/model_bakery/pull/206) - Add support for iterators on GFK fields when using _quantity param [PR #207](https://github.com/model-bakers/model_bakery/pull/207) +- Add support for iterators on many-to-many fields [PR#237](https://github.com/model-bakers/model_bakery/pull/237) ### Changed - Fix typos in Recipes documentation page [PR #212](https://github.com/model-bakers/model_bakery/pull/212) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 843bf0cd..9cc1b9a4 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -357,7 +357,12 @@ def _make( if field.name not in self.model_attrs: self.m2m_dict[field.name] = self.m2m_value(field) else: - self.m2m_dict[field.name] = self.model_attrs.pop(field.name) + if field.name in self.iterator_attrs: + self.model_attrs[field.name] = [ + next(self.iterator_attrs[field.name]) + ] + else: + self.m2m_dict[field.name] = self.model_attrs.pop(field.name) elif field.name not in self.model_attrs: if ( not isinstance(field, ForeignKey) @@ -503,14 +508,18 @@ def _skip_field(self, field: Field) -> bool: return False def _handle_one_to_many(self, instance: Model, attrs: Dict[str, Any]): - for k, v in attrs.items(): - manager = getattr(instance, k) + for key, values in attrs.items(): + manager = getattr(instance, key) + + for value in values: + if not value.pk: + value.save() try: - manager.set(v, bulk=False, clear=True) + manager.set(values, bulk=False, clear=True) except TypeError: # for many-to-many relationships the bulk keyword argument doesn't exist - manager.set(v, clear=True) + manager.set(values, clear=True) def _handle_m2m(self, instance: Model): for key, values in self.m2m_dict.items(): diff --git a/tests/test_baker.py b/tests/test_baker.py index a1110f6b..baefab7e 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -405,6 +405,47 @@ def test_regresstion_many_to_many_field_is_accepted_as_kwargs(self): assert store.customers.count() == 3 assert models.Person.objects.count() == 6 + def test_create_many_to_many_with_iter(self): + students = baker.make(models.Person, _quantity=3) + classrooms = baker.make(models.Classroom, _quantity=3, students=iter(students)) + + assert classrooms[0].students.count() == 1 + assert classrooms[0].students.first() == students[0] + assert classrooms[1].students.count() == 1 + assert classrooms[1].students.first() == students[1] + assert classrooms[2].students.count() == 1 + assert classrooms[2].students.first() == students[2] + + def test_create_many_to_many_with_unsaved_iter(self): + students = baker.prepare(models.Person, _quantity=3) + classrooms = baker.make(models.Classroom, _quantity=3, students=iter(students)) + + assert students[0].pk is not None + assert students[1].pk is not None + assert students[2].pk is not None + + assert classrooms[0].students.count() == 1 + assert classrooms[0].students.first() == students[0] + assert classrooms[1].students.count() == 1 + assert classrooms[1].students.first() == students[1] + assert classrooms[2].students.count() == 1 + assert classrooms[2].students.first() == students[2] + + def test_create_many_to_many_with_through_and_iter(self): + students = baker.make(models.Person, _quantity=3) + schools = baker.make( + models.School, + _quantity=3, + students=iter(students), + ) + + assert schools[0].students.count() == 1 + assert schools[0].students.first() == students[0] + assert schools[1].students.count() == 1 + assert schools[1].students.first() == students[1] + assert schools[2].students.count() == 1 + assert schools[2].students.first() == students[2] + def test_create_many_to_many_with_set_default_quantity(self): store = baker.make(models.Store, make_m2m=True) assert store.employees.count() == baker.MAX_MANY_QUANTITY