Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for bulk_update #148

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions docs/making_queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ notes = await Note.objects.filter(completed=True).all()

There are some special operators defined automatically on every column:

* `in` - SQL `IN` operator.
* `exact` - filter instances matching exact value.
* `iexact` - same as `exact` but case-insensitive.
* `contains` - filter instances containing value.
* `icontains` - same as `contains` but case-insensitive.
* `lt` - filter instances having value `Less Than`.
* `lte` - filter instances having value `Less Than Equal`.
* `gt` - filter instances having value `Greater Than`.
* `gte` - filter instances having value `Greater Than Equal`.
- `in` - SQL `IN` operator.
- `exact` - filter instances matching exact value.
- `iexact` - same as `exact` but case-insensitive.
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
- `contains` - filter instances containing value.
- `icontains` - same as `contains` but case-insensitive.
- `lt` - filter instances having value `Less Than`.
- `lte` - filter instances having value `Less Than Equal`.
- `gt` - filter instances having value `Greater Than`.
- `gte` - filter instances having value `Greater Than Equal`.

Example usage:

Expand All @@ -84,7 +84,7 @@ notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all()
Here `Note.columns` refers to the columns of the underlying SQLAlchemy table.

!!! note
Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields.
Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields.
aminalaee marked this conversation as resolved.
Show resolved Hide resolved

### .limit()

Expand Down Expand Up @@ -119,7 +119,7 @@ notes = await Note.objects.order_by("text", "-id").all()
```

!!! note
This will sort by ascending `text` and descending `id`.
This will sort by ascending `text` and descending `id`.

## Returning results

Expand All @@ -146,10 +146,10 @@ await Note.objects.create(text="Send invoices.", completed=True)
You need to pass a list of dictionaries of required fields to create multiple objects:

```python
await Product.objects.bulk_create(
await Note.objects.bulk_create(
[
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
{"text": "Buy the groceries", "completed": False},
{"text": "Call Mum.", "completed": True},

]
)
Expand Down Expand Up @@ -209,7 +209,7 @@ note = await Note.objects.get(id=1)
```

!!! note
`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`.
`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`.

### .update()

Expand All @@ -233,6 +233,18 @@ note = await Note.objects.first()
await note.update(completed=True)
```

### .bulk_update()

You can also bulk update multiple objects at once by passing a list of objects and a list of fields to update.

```python
notes = await Note.objects.all()
for note in notes :
note.completed = True

await Note.objects.bulk_update(notes, fields=["completed"])
```

## Convenience Methods

### .get_or_create()
Expand All @@ -250,8 +262,7 @@ This will query a `Note` with `text` as `"Going to car wash"`,
if it doesn't exist, it will use `defaults` argument to create the new instance.

!!! note
Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.

Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.

### .update_or_create()

Expand All @@ -269,4 +280,4 @@ if an instance is found, it will use the `defaults` argument to update the insta
If it matches no records, it will use the comibnation of arguments to create the new instance.

!!! note
Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
48 changes: 48 additions & 0 deletions orm/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import enum
import json
import typing

import databases
Expand All @@ -20,6 +22,8 @@
"lte": "__le__",
}

MODEL = typing.TypeVar("MODEL", bound="Model")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What value does this bring? I mean we could call bulk_update with Model itself. right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? We use it as a type annotation for obj

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I understand. I mean you've done this:

def bulk_update(self, objects: typing.List[MODEL], ...):
    ....

What would be the difference if we did:

def bulk_update(self, objects: typing.List[Model], ...):
    ....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case Model will be undefined because it has been defined after bulk_update

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe "Model" ?



def _update_auto_now_fields(values, fields):
for key, value in fields.items():
Expand All @@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields):
return values


def _convert_value(value):
if isinstance(value, dict):
return json.dumps(value)
elif isinstance(value, enum.Enum):
return value.name
else:
return value


class ModelRegistry:
def __init__(self, database: databases.Database) -> None:
self.database = database
Expand Down Expand Up @@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None:

await self.database.execute(expr)

async def bulk_update(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I should've noticed this earlier, apologies for that.
But maybe a general refactor would be useful here?
There's a lot of nested code here and it's not very readable. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree with you it needs to be more readable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aminalaee Any updates ?

self, objs: typing.List[MODEL], fields: typing.List[str]
) -> None:
fields = {
key: field.validator
for key, field in self.model_cls.fields.items()
if key in fields
}
validator = typesystem.Schema(fields=fields)
new_objs = [
_update_auto_now_fields(validator.validate(value), self.model_cls.fields)
for value in [
{
key: _convert_value(value)
for key, value in obj.__dict__.items()
if key in fields
}
for obj in objs
]
]
expr = (
self.table.update()
.where(self.table.c.id == sqlalchemy.bindparam("id"))
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
.values(
{
field: sqlalchemy.bindparam(field)
for obj in new_objs
for field in obj.keys()
}
)
)
pk_list = [{"id": obj.pk} for obj in objs]
joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)]
await self.database.execute_many(str(expr), joined_list)

async def get_or_create(
self, defaults: typing.Dict[str, typing.Any], **kwargs
) -> typing.Tuple[typing.Any, bool]:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,79 @@ async def test_bulk_create():
assert products[1].data == {"foo": 456}
assert products[1].value == 456.789
assert products[1].status == StatusEnum.DRAFT


async def test_bulk_update():
await Product.objects.bulk_create(
[
{
"created": "2020-01-01T00:00:00Z",
"data": {"foo": 123},
"value": 123.456,
"status": StatusEnum.RELEASED,
},
{
"created": "2020-01-01T00:00:00Z",
"data": {"foo": 456},
"value": 456.789,
"status": StatusEnum.DRAFT,
},
]
)
products = await Product.objects.all()
products[0].created = "2021-01-01T00:00:00Z"
products[1].created = "2022-01-01T00:00:00Z"
products[0].status = StatusEnum.DRAFT
products[1].status = StatusEnum.RELEASED
products[0].data = {"foo": 1234}
products[1].data = {"foo": 5678}
products[0].value = 1234.567
products[1].value = 5678.891
await Product.objects.bulk_update(
products, fields=["created", "status", "data", "value"]
)
products = await Product.objects.all()
assert products[0].created == datetime.datetime(2021, 1, 1, 0, 0, 0)
assert products[1].created == datetime.datetime(2022, 1, 1, 0, 0, 0)
assert products[0].status == StatusEnum.DRAFT
assert products[1].status == StatusEnum.RELEASED
assert products[0].data == {"foo": 1234}
assert products[1].data == {"foo": 5678}
assert products[0].value == 1234.567
assert products[1].value == 5678.891


async def test_bulk_update_with_relation():
class Album(orm.Model):
registry = models
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
fields = {
"id": orm.Integer(primary_key=True),
"name": orm.Text(),
}

class Track(orm.Model):
registry = models
fields = {
"id": orm.Integer(primary_key=True),
"name": orm.Text(),
"album": orm.ForeignKey(Album),
}

await models.create_all()

album = await Album.objects.create(name="foo")
album2 = await Album.objects.create(name="bar")

await Track.objects.bulk_create(
[
{"name": "foo", "album": album},
{"name": "bar", "album": album},
]
)
tracks = await Track.objects.all()
for track in tracks:
track.album = album2
await Track.objects.bulk_update(tracks, fields=["album"])
tracks = await Track.objects.all()
assert tracks[0].album.pk == album2.pk
assert tracks[1].album.pk == album2.pk