Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transaction refactoring #234

Merged
merged 15 commits into from
May 13, 2024
348 changes: 130 additions & 218 deletions peewee_async.py

Large diffs are not rendered by default.

85 changes: 80 additions & 5 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Copyright (c) 2024, Alexey Kinëv <rudy@05bit.com>

"""
import uuid
from functools import partial

import warnings
Expand All @@ -31,6 +32,9 @@
'prefetch',
'execute',
'scalar',
'atomic',
'transaction',
'savepoint',
]


Expand Down Expand Up @@ -390,23 +394,28 @@ def atomic(self):
await objects.create(
PageBlock, key='signature', text="William Shakespeare")
"""
from peewee_async import atomic # noqa
return atomic(self.database)
warnings.warn(
"`atomic` is deprecated, use `database.aio_atomic` method.",
DeprecationWarning
)
return self.database.aio_atomic()

def transaction(self):
"""
Similar to `peewee.Database.transaction()` method, but returns
**asynchronous** context manager.
"""
from peewee_async import transaction # noqa
return transaction(self.database)
warnings.warn(
"`transaction` is deprecated, use `database.aio_atomic` method.",
DeprecationWarning
)
return self.database.aio_atomic()

def savepoint(self, sid=None):
"""
Similar to `peewee.Database.savepoint()` method, but returns
**asynchronous** context manager.
"""
from peewee_async import savepoint # noqa
return savepoint(self.database, sid=sid)

def allow_sync(self):
Expand Down Expand Up @@ -436,3 +445,69 @@ def _prune_fields(field_dict, only):
if f not in fields:
field_dict.pop(f)
return field_dict


def transaction(db):
"""Asynchronous context manager (`async with`), similar to
`peewee.transaction()`. Will start new `asyncio` task for
transaction if not started already.
"""
warnings.warn(
"`transaction` is deprecated, use `database.aio_atomic` or `Transaction` class instead.",
DeprecationWarning
)
from peewee_async import TransactionContextManager
return TransactionContextManager(db.aio_pool)


def atomic(db):
"""Asynchronous context manager (`async with`), similar to
`peewee.atomic()`.
"""
warnings.warn(
"`atomic` is deprecated, use `database.aio_atomic` or `Transaction` class instead.",
DeprecationWarning
)
from peewee_async import TransactionContextManager
return TransactionContextManager(db.aio_pool)

class _savepoint:
"""Asynchronous context manager (`async with`), similar to
`peewee.savepoint()`.
"""
def __init__(self, db, sid=None):

self.db = db
self.sid = sid or 's' + uuid.uuid4().hex
self.quoted_sid = self.sid.join(self.db.quote)

async def commit(self):
await self.db.aio_execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid)

async def rollback(self):
await self.db.aio_execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

async def __aenter__(self):
await self.db.aio_execute_sql('SAVEPOINT %s;' % self.quoted_sid)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type:
await self.rollback()
else:
try:
await self.commit()
except:
await self.rollback()
raise
finally:
pass


def savepoint(db, sid=None):
warnings.warn(
"`savepoint` is deprecated, use `database.aio_atomic` or `Transaction` class instead.",
DeprecationWarning
)
return _savepoint(db.db, sid)
Empty file added tests/compat/__init__.py
Empty file.
File renamed without changes.
112 changes: 112 additions & 0 deletions tests/compat/test_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import asyncio

from tests.conftest import all_dbs
from tests.models import TestModel


class FakeUpdateError(Exception):
"""Fake error while updating database.
"""
pass


@all_dbs
async def test_atomic_success(manager):
obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id

async with manager.atomic():
obj.text = 'BAR'
await manager.update(obj)

res = await manager.get(TestModel, id=obj_id)
assert res.text == 'BAR'


@all_dbs
async def test_atomic_failed(manager):
"""Failed update in transaction.
"""

obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id

try:
async with manager.atomic():
obj.text = 'BAR'
await manager.update(obj)
raise FakeUpdateError()
except FakeUpdateError as e:
error = True
res = await manager.get(TestModel, id=obj_id)

assert error is True
assert res.text == 'FOO'


@all_dbs
async def test_acid_when_connetion_has_been_brooken(manager):
# TODO REWRITE FOR NEW STYLE TRANSACTIONS
async def restart_connections(event_for_lock: asyncio.Event) -> None:
event_for_lock.set()
# С этого момента БД доступна, пулл в порядке, всё хорошо. Таски могут работать работу
await asyncio.sleep(0.05)

# Ниже происходит падение метеорита
# (приложением обнаруживается разрыв коннекта с БД и оно сливает пулл + заполняет заново)
# (может выглядеть нереалистично, но такое бывает и на проде, мы просто симулируем это редкое событие)
#
# Мы через event запрещаем таскам трогать БД на время переподнятия пулла.
# Это нужно, чтобы воспроизвести очень редкие условия, при которых peewee-async косячит.
event_for_lock.clear()

await manager.database.close_async()
await manager.database.connect_async()

event_for_lock.set()

# БД самопочинилась (пулл заполнен и готов к работе).
# С этого момента таски опять могут работать работу
return None

async def insert_records(event_for_wait: asyncio.Event):
await event_for_wait.wait()
async with manager.transaction():
# BEGIN
# INSERT 1
await manager.create(TestModel, text="1")

# Это место для падения метеорита.
# Тут произойдёт разрыв соединения и подключение заново.
# event здесь нужен чтобы таска не упала с исключением, а воспроизвела редкое поведение peewee-async
await asyncio.sleep(0.05)
await event_for_wait.wait()
# # Метеорит позади. event-loop вернул управление в таску
#
# # INSERT 2
await manager.create(TestModel, text="2")
# END ?

return None


# Этот event-семафор нужно чтобы удобно воссоздать редкую ситуацию,
# когда разрыв + восстановление происходит до того, как в asyncio.Task вернётся управление.
# В дикой природе такое происходит редко и при определённых стечениях обстоятельств,
# но приносит ощутимый ущерб
event = asyncio.Event()

results = await asyncio.gather(
restart_connections(event),
insert_records(event),
return_exceptions=True,
)
# Проверяем, что ни одна из тасок не упала с исключением
# assert results == [None, None]

# (!) Убеждаемся, что атомарность работает (!)
# Т.е. у нас должны либо закоммититься 2 записи, либо ни одной
a = list(await manager.execute(TestModel.select()))
assert len(a) == 0, f'WTF, peewee-async ?! Saved rows: {a}'
# Если assert выше упал, то в БД оказалась 1 запись, а не 0 или 2.
# Хотя мы на уровне кода пытались гарантировать, что такого не будет
41 changes: 36 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def event_loop():


@pytest.fixture
async def manager(request):
async def db(request):
db = request.param
if db.startswith('postgres') and aiopg is None:
pytest.skip("aiopg is not installed")
Expand All @@ -36,15 +36,14 @@ async def manager(request):
params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)
database._allow_sync = False
manager = peewee_async.Manager(database)
with manager.allow_sync():
with database.allow_sync():
for model in ALL_MODELS:
model._meta.database = database
model.create_table(True)

yield peewee_async.Manager(database)
yield database

with manager.allow_sync():
with database.allow_sync():
for model in reversed(sort_models(ALL_MODELS)):
model.delete().execute()
model._meta.database = None
Expand All @@ -61,6 +60,12 @@ async def manager(request):
MYSQL_DBS = ["mysql", "mysql-pool"]


dbs_all = pytest.mark.parametrize(
"db", PG_DBS + MYSQL_DBS, indirect=["db"]
)

# MANAGERS fixtures will be removed in v1.0.0
# TODO add manager prefix to fixtures
postgres_only = pytest.mark.parametrize(
"manager", PG_DBS, indirect=["manager"]
)
Expand All @@ -72,3 +77,29 @@ async def manager(request):
all_dbs = pytest.mark.parametrize(
"manager", PG_DBS + MYSQL_DBS, indirect=["manager"]
)


@pytest.fixture
async def manager(request):
db = request.param
if db.startswith('postgres') and aiopg is None:
pytest.skip("aiopg is not installed")
if db.startswith('mysql') and aiomysql is None:
pytest.skip("aiomysql is not installed")

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)
database._allow_sync = False
manager = peewee_async.Manager(database)
with manager.allow_sync():
for model in ALL_MODELS:
model._meta.database = database
model.create_table(True)

yield peewee_async.Manager(database)

with manager.allow_sync():
for model in reversed(sort_models(ALL_MODELS)):
model.delete().execute()
model._meta.database = None
await database.close_async()
19 changes: 19 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from peewee_async import connection_context
from tests.conftest import dbs_all
from tests.models import TestModel


@dbs_all
async def test_nested_connection(db):
async with db.aio_connection() as connection_1:
async with connection_1.cursor() as cursor:
await cursor.execute("SELECT 1")
await TestModel.aio_get_or_none(id=5)
async with db.aio_connection() as connection_2:
assert connection_1 is connection_2
_connection = connection_context.get().connection
assert _connection is connection_2
async with connection_2.cursor() as cursor:
await cursor.execute("SELECT 1")
assert connection_context.get() is None
assert db.aio_pool.has_acquired_connections() is False
Loading
Loading