diff --git a/docs/making_queries.md b/docs/making_queries.md index f7b9ec1..b36a2d3 100644 --- a/docs/making_queries.md +++ b/docs/making_queries.md @@ -146,10 +146,10 @@ await Note.objects.create(text="Send invoices.", completed=True) You need to pass a list of dictionaries of required fields to create multiple objects: ```python -await Product.objects.bulk_create( +await Note.objects.bulk_create( [ - {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, - {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + {"text": "Buy the groceries", "completed": False}, + {"text": "Call Mum.", "completed": True}, ] ) @@ -233,6 +233,18 @@ note = await Note.objects.first() await note.update(completed=True) ``` +### .bulk_update() + +You can also bulk update multiple objects at once by passing a list of objects and a list of fields to update. + +```python +notes = await Note.objects.all() +for note in notes : + note.completed = True + +await Note.objects.bulk_update(notes, fields=["completed"]) +``` + ## Convenience Methods ### .get_or_create() @@ -252,7 +264,6 @@ if it doesn't exist, it will use `defaults` argument to create the new instance. !!! note Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. - ### .update_or_create() To update an existing instance matching the query, or create a new one. diff --git a/orm/models.py b/orm/models.py index b402814..2766bd2 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,3 +1,5 @@ +import enum +import json import typing import databases @@ -28,6 +30,15 @@ def _update_auto_now_fields(values, fields): return values +def _convert_value(value): + if isinstance(value, dict): + return json.dumps(value) + elif isinstance(value, enum.Enum): + return value.name + else: + return value + + class ModelRegistry: def __init__(self, database: databases.Database) -> None: self.database = database @@ -454,6 +465,39 @@ async def update(self, **kwargs) -> None: await self.database.execute(expr) + async def bulk_update( + self, objs: typing.List["Model"], fields: typing.List[str] + ) -> None: + fields = { + key: field.validator + for key, field in self.model_cls.fields.items() + if key in fields + } + validator = typesystem.Schema(fields=fields) + new_objs = [ + { + key: _convert_value(value) + for key, value in obj.__dict__.items() + if key in fields + } + for obj in objs + ] + new_objs = [ + _update_auto_now_fields(validator.validate(obj), self.model_cls.fields) + for obj in new_objs + ] + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.update().where(pk_column == sqlalchemy.bindparam(self.pkname)) + kwargs = { + field: sqlalchemy.bindparam(field) + for obj in new_objs + for field in obj.keys() + } + expr = expr.values(kwargs) + pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs] + joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] + await self.database.execute_many(str(expr), joined_list) + async def get_or_create( self, defaults: typing.Dict[str, typing.Any], **kwargs ) -> typing.Tuple[typing.Any, bool]: diff --git a/tests/test_columns.py b/tests/test_columns.py index 278aecd..3fd7656 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -159,3 +159,43 @@ async def test_bulk_create(): assert products[1].data == {"foo": 456} assert products[1].value == 456.789 assert products[1].status == StatusEnum.DRAFT + + +async def test_bulk_update(): + await Product.objects.bulk_create( + [ + { + "created_day": datetime.date.today(), + "data": {"foo": 123}, + "value": 123.456, + "status": StatusEnum.RELEASED, + }, + { + "created_day": datetime.date.today(), + "data": {"foo": 456}, + "value": 456.789, + "status": StatusEnum.DRAFT, + }, + ] + ) + products = await Product.objects.all() + products[0].created_day = datetime.date.today() - datetime.timedelta(days=1) + products[1].created_day = datetime.date.today() - datetime.timedelta(days=1) + products[0].status = StatusEnum.DRAFT + products[1].status = StatusEnum.RELEASED + products[0].data = {"foo": 1234} + products[1].data = {"foo": 5678} + products[0].value = 345.5 + products[1].value = 789.8 + await Product.objects.bulk_update( + products, fields=["created_day", "status", "data", "value"] + ) + products = await Product.objects.all() + assert products[0].created_day == datetime.date.today() - datetime.timedelta(days=1) + assert products[1].created_day == datetime.date.today() - datetime.timedelta(days=1) + assert products[0].status == StatusEnum.DRAFT + assert products[1].status == StatusEnum.RELEASED + assert products[0].data == {"foo": 1234} + assert products[1].data == {"foo": 5678} + assert products[0].value == 345.5 + assert products[1].value == 789.8 diff --git a/tests/test_foreignkey.py b/tests/test_foreignkey.py index 1ab6175..8b9e5ef 100644 --- a/tests/test_foreignkey.py +++ b/tests/test_foreignkey.py @@ -278,3 +278,22 @@ async def test_nullable_foreign_key(): assert member.email == "dev@encode.io" assert member.team.pk is None + + +async def test_bulk_update_with_relation(): + album = await Album.objects.create(name="foo") + album2 = await Album.objects.create(name="bar") + + await Track.objects.bulk_create( + [ + {"name": "foo", "album": album, "position": 1, "title": "foo"}, + {"name": "bar", "album": album, "position": 2, "title": "bar"}, + ] + ) + tracks = await Track.objects.all() + for track in tracks: + track.album = album2 + await Track.objects.bulk_update(tracks, fields=["album"]) + tracks = await Track.objects.all() + assert tracks[0].album.pk == album2.pk + assert tracks[1].album.pk == album2.pk