Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,3 +882,24 @@ async def aio_create(cls, **query):
inst = cls(**query)
await inst.aio_save(force_insert=True)
return inst

@classmethod
async def aio_get_or_create(cls, **kwargs):
defaults = kwargs.pop('defaults', {})
query = cls.select()
for field, value in kwargs.items():
query = query.where(getattr(cls, field) == value)

try:
return await query.aio_get(), False
except cls.DoesNotExist:
try:
if defaults:
kwargs.update(defaults)
async with cls._meta.database.aio_atomic():
return await cls.aio_create(**kwargs), True
except peewee.IntegrityError as exc:
try:
return await query.aio_get(), False
except cls.DoesNotExist:
raise exc
8 changes: 8 additions & 0 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ async def get_or_create(self, model_, defaults=None, **kwargs):
Return 2-tuple containing the model instance and a boolean
indicating whether the instance was created.
"""
warnings.warn(
"`get_or_create` method is deprecated, use `AioModel.aio_get_or_create` instead.",
DeprecationWarning
)
try:
return (await self.get(model_, **kwargs)), False
except model_.DoesNotExist:
Expand Down Expand Up @@ -342,6 +346,10 @@ async def create_or_get(self, model_, **kwargs):
Try to create new object with specified data. If object already
exists, then try to get it by unique fields.
"""
warnings.warn(
"`create_or_get` method is deprecated, use `AioModel.aio_get_or_create` instead.",
DeprecationWarning
)
try:
return (await self.create(model_, **kwargs)), True
except IntegrityErrors:
Expand Down
15 changes: 15 additions & 0 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,18 @@ async def test_aio_save__force_insert(db):

with pytest.raises(peewee.IntegrityError):
await t.aio_save(force_insert=True)


@dbs_all
async def test_aio_get_or_create__get(db):
t1 = await TestModel.aio_create(text="text", data="data")
t2, created = await TestModel.aio_get_or_create(text="text")
assert t1.id == t2.id
assert created is False


@dbs_all
async def test_aio_get_or_create__created(db):
t2, created = await TestModel.aio_get_or_create(text="text")
assert t2.text == "text"
assert created is True