From 7325e5b66230e6a2244d6c0c18f5d0ff45d0739f Mon Sep 17 00:00:00 2001 From: kalombo Date: Sat, 18 May 2024 22:08:39 +0500 Subject: [PATCH] fix: rewrite transaction context manager --- peewee_async.py | 36 ++++++----------- pyproject.toml | 5 ++- tests/test_transaction.py | 85 ++++++++++++++++++++++++++------------- 3 files changed, 73 insertions(+), 53 deletions(-) diff --git a/peewee_async.py b/peewee_async.py index fc4315e..b19d05f 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -236,28 +236,6 @@ async def __aexit__(self, *args): connection_context.set(None) -class TransactionContextManager(ConnectionContextManager): - async def __aenter__(self): - connection = await super().__aenter__() - begin_transaction = self.connection_context.transaction_is_opened is False - - self.transaction = Transaction(connection, is_savepoint=begin_transaction is False) - await self.transaction.__aenter__() - - if begin_transaction is True: - self.connection_context.transaction_is_opened = True - return connection - - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self.transaction.__aexit__(exc_type, exc_value, exc_tb) - - end_transaction = self.transaction.is_savepoint is False - if end_transaction is True: - self.connection_context.transaction_is_opened = False - - await super().__aexit__() - - ############ # Database # ############ @@ -286,11 +264,21 @@ async def aio_close(self): """ await self.aio_pool.terminate() - def aio_atomic(self): + @contextlib.asynccontextmanager + async def aio_atomic(self): """Similar to peewee `Database.atomic()` method, but returns asynchronous context manager. """ - return TransactionContextManager(self.aio_pool) + async with self.aio_connection() as connection: + _connection_context = connection_context.get() + begin_transaction = _connection_context.transaction_is_opened is False + try: + async with Transaction(connection, is_savepoint=begin_transaction is False): + _connection_context.transaction_is_opened = True + yield + finally: + if begin_transaction is True: + _connection_context.transaction_is_opened = False def set_allow_sync(self, value): """Allow or forbid sync queries for the database. See also diff --git a/pyproject.toml b/pyproject.toml index 691e8ee..6863107 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "peewee-async" -version = "0.9.1" +version = "0.10.1-beta" description = "Asynchronous interface for peewee ORM powered by asyncio." authors = ["Alexey Kinev ", "Gorshkov Nikolay(contributor) "] readme = "README.md" @@ -19,13 +19,14 @@ aiomysql = { version = "^0.2.0", optional = true } cryptography = { version = "^41.0.3", optional = true } pytest = { version = "^7.4.1", optional = true } pytest-asyncio = { version = "^0.21.1", optional = true } +pytest-mock = { version = "^3.14.0", optional = true } sphinx = { version = "^7.1.2", optional = true } sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true } [tool.poetry.extras] postgresql = ["aiopg"] mysql = ["aiomysql", "cryptography"] -develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio"] +develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock"] docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"] [build-system] diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 9c238ff..e5187b1 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,46 +1,53 @@ import asyncio +import pytest +from peewee import IntegrityError + from peewee_async import Transaction from tests.conftest import dbs_all from tests.models import TestModel -@dbs_all -async def test_savepoint_success(db): +class FakeConnectionError(Exception): + pass - async with db.aio_atomic(): - await TestModel.aio_create(text='FOO') +@dbs_all +async def test_transaction_error_on_begin(db, mocker): + mocker.patch.object(Transaction, "begin", side_effect=FakeConnectionError) + with pytest.raises(FakeConnectionError): async with db.aio_atomic(): - await TestModel.update(text="BAR").aio_execute() + await TestModel.aio_create(text='FOO') + assert db.aio_pool.has_acquired_connections() is False - assert await TestModel.aio_get_or_none(text="BAR") is not None +@dbs_all +async def test_transaction_error_on_commit(db, mocker): + mocker.patch.object(Transaction, "commit", side_effect=FakeConnectionError) + with pytest.raises(FakeConnectionError): + async with db.aio_atomic(): + await TestModel.aio_create(text='FOO') assert db.aio_pool.has_acquired_connections() is False @dbs_all -async def test_transaction_success(db): - async with db.aio_atomic(): - await TestModel.aio_create(text='FOO') +async def test_transaction_error_on_rollback(db, mocker): + await TestModel.aio_create(text='FOO', data="") + mocker.patch.object(Transaction, "rollback", side_effect=FakeConnectionError) + with pytest.raises(FakeConnectionError): + async with db.aio_atomic(): + await TestModel.update(data="BAR").aio_execute() + assert await TestModel.aio_get_or_none(data="BAR") is not None + await TestModel.aio_create(text='FOO') - assert await TestModel.aio_get_or_none(text="FOO") is not None assert db.aio_pool.has_acquired_connections() is False @dbs_all -async def test_savepoint_rollback(db): - await TestModel.aio_create(text='FOO', data="") - +async def test_transaction_success(db): async with db.aio_atomic(): - await TestModel.update(data="BAR").aio_execute() - - try: - async with db.aio_atomic(): - await TestModel.aio_create(text='FOO') - except: - pass + await TestModel.aio_create(text='FOO') - assert await TestModel.aio_get_or_none(data="BAR") is not None + assert await TestModel.aio_get_or_none(text="FOO") is not None assert db.aio_pool.has_acquired_connections() is False @@ -48,13 +55,11 @@ async def test_savepoint_rollback(db): async def test_transaction_rollback(db): await TestModel.aio_create(text='FOO', data="") - try: + with pytest.raises(IntegrityError): async with db.aio_atomic(): await TestModel.update(data="BAR").aio_execute() assert await TestModel.aio_get_or_none(data="BAR") is not None await TestModel.aio_create(text='FOO') - except: - pass assert await TestModel.aio_get_or_none(data="BAR") is None assert db.aio_pool.has_acquired_connections() is False @@ -72,11 +77,9 @@ async def t1(): async def t2(): async with db.aio_atomic(): await TestModel.aio_create(text='FOO2', data="") - try: + with pytest.raises(IntegrityError): async with db.aio_atomic(): await TestModel.aio_create(text='FOO2', data="not_created") - except: - pass async def t3(): async with db.aio_atomic(): @@ -110,6 +113,33 @@ async def test_transaction_manual_work(db): assert db.aio_pool.has_acquired_connections() is False +@dbs_all +async def test_savepoint_success(db): + async with db.aio_atomic(): + await TestModel.aio_create(text='FOO') + + async with db.aio_atomic(): + await TestModel.update(text="BAR").aio_execute() + + assert await TestModel.aio_get_or_none(text="BAR") is not None + assert db.aio_pool.has_acquired_connections() is False + + +@dbs_all +async def test_savepoint_rollback(db): + await TestModel.aio_create(text='FOO', data="") + + async with db.aio_atomic(): + await TestModel.update(data="BAR").aio_execute() + + with pytest.raises(IntegrityError): + async with db.aio_atomic(): + await TestModel.aio_create(text='FOO') + + assert await TestModel.aio_get_or_none(data="BAR") is not None + assert db.aio_pool.has_acquired_connections() is False + + @dbs_all async def test_savepoint_manual_work(db): async with db.aio_connection() as connection: @@ -178,3 +208,4 @@ async def insert_records(event_for_wait: asyncio.Event): # The transaction has not been committed assert len(list(await TestModel.select().aio_execute())) == 0 assert db.aio_pool.has_acquired_connections() is False +