Skip to content

Commit aaaccf9

Browse files
committed
fix: rewrite transaction context manager
1 parent 1a44d4b commit aaaccf9

File tree

3 files changed

+73
-53
lines changed

3 files changed

+73
-53
lines changed

peewee_async.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -236,28 +236,6 @@ async def __aexit__(self, *args):
236236
connection_context.set(None)
237237

238238

239-
class TransactionContextManager(ConnectionContextManager):
240-
async def __aenter__(self):
241-
connection = await super().__aenter__()
242-
begin_transaction = self.connection_context.transaction_is_opened is False
243-
244-
self.transaction = Transaction(connection, is_savepoint=begin_transaction is False)
245-
await self.transaction.__aenter__()
246-
247-
if begin_transaction is True:
248-
self.connection_context.transaction_is_opened = True
249-
return connection
250-
251-
async def __aexit__(self, exc_type, exc_value, exc_tb):
252-
await self.transaction.__aexit__(exc_type, exc_value, exc_tb)
253-
254-
end_transaction = self.transaction.is_savepoint is False
255-
if end_transaction is True:
256-
self.connection_context.transaction_is_opened = False
257-
258-
await super().__aexit__()
259-
260-
261239
############
262240
# Database #
263241
############
@@ -286,11 +264,21 @@ async def aio_close(self):
286264
"""
287265
await self.aio_pool.terminate()
288266

289-
def aio_atomic(self):
267+
@contextlib.asynccontextmanager
268+
async def aio_atomic(self):
290269
"""Similar to peewee `Database.atomic()` method, but returns
291270
asynchronous context manager.
292271
"""
293-
return TransactionContextManager(self.aio_pool)
272+
async with self.aio_connection() as connection:
273+
_connection_context = connection_context.get()
274+
begin_transaction = _connection_context.transaction_is_opened is False
275+
try:
276+
async with Transaction(connection, is_savepoint=begin_transaction is False):
277+
_connection_context.transaction_is_opened = True
278+
yield
279+
finally:
280+
if begin_transaction is True:
281+
_connection_context.transaction_is_opened = False
294282

295283
def set_allow_sync(self, value):
296284
"""Allow or forbid sync queries for the database. See also

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "peewee-async"
3-
version = "0.9.1"
3+
version = "0.10.1-beta"
44
description = "Asynchronous interface for peewee ORM powered by asyncio."
55
authors = ["Alexey Kinev <rudy@05bit.com>", "Gorshkov Nikolay(contributor) <nogamemorebrain@gmail.com>"]
66
readme = "README.md"
@@ -19,13 +19,14 @@ aiomysql = { version = "^0.2.0", optional = true }
1919
cryptography = { version = "^41.0.3", optional = true }
2020
pytest = { version = "^7.4.1", optional = true }
2121
pytest-asyncio = { version = "^0.21.1", optional = true }
22+
pytest-mock = { version = "^3.14.0", optional = true }
2223
sphinx = { version = "^7.1.2", optional = true }
2324
sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true }
2425

2526
[tool.poetry.extras]
2627
postgresql = ["aiopg"]
2728
mysql = ["aiomysql", "cryptography"]
28-
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio"]
29+
develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock"]
2930
docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"]
3031

3132
[build-system]

tests/test_transaction.py

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,65 @@
11
import asyncio
22

3+
import pytest
4+
from peewee import IntegrityError
5+
36
from peewee_async import Transaction
47
from tests.conftest import dbs_all
58
from tests.models import TestModel
69

710

8-
@dbs_all
9-
async def test_savepoint_success(db):
11+
class FakeConnectionError(Exception):
12+
pass
1013

11-
async with db.aio_atomic():
12-
await TestModel.aio_create(text='FOO')
1314

15+
@dbs_all
16+
async def test_transaction_error_on_begin(db, mocker):
17+
mocker.patch.object(Transaction, "begin", side_effect=FakeConnectionError)
18+
with pytest.raises(FakeConnectionError):
1419
async with db.aio_atomic():
15-
await TestModel.update(text="BAR").aio_execute()
20+
await TestModel.aio_create(text='FOO')
21+
assert db.aio_pool.has_acquired_connections() is False
1622

17-
assert await TestModel.aio_get_or_none(text="BAR") is not None
23+
@dbs_all
24+
async def test_transaction_error_on_commit(db, mocker):
25+
mocker.patch.object(Transaction, "commit", side_effect=FakeConnectionError)
26+
with pytest.raises(FakeConnectionError):
27+
async with db.aio_atomic():
28+
await TestModel.aio_create(text='FOO')
1829
assert db.aio_pool.has_acquired_connections() is False
1930

2031

2132
@dbs_all
22-
async def test_transaction_success(db):
23-
async with db.aio_atomic():
24-
await TestModel.aio_create(text='FOO')
33+
async def test_transaction_error_on_rollback(db, mocker):
34+
await TestModel.aio_create(text='FOO', data="")
35+
mocker.patch.object(Transaction, "rollback", side_effect=FakeConnectionError)
36+
with pytest.raises(FakeConnectionError):
37+
async with db.aio_atomic():
38+
await TestModel.update(data="BAR").aio_execute()
39+
assert await TestModel.aio_get_or_none(data="BAR") is not None
40+
await TestModel.aio_create(text='FOO')
2541

26-
assert await TestModel.aio_get_or_none(text="FOO") is not None
2742
assert db.aio_pool.has_acquired_connections() is False
2843

2944

3045
@dbs_all
31-
async def test_savepoint_rollback(db):
32-
await TestModel.aio_create(text='FOO', data="")
33-
46+
async def test_transaction_success(db):
3447
async with db.aio_atomic():
35-
await TestModel.update(data="BAR").aio_execute()
36-
37-
try:
38-
async with db.aio_atomic():
39-
await TestModel.aio_create(text='FOO')
40-
except:
41-
pass
48+
await TestModel.aio_create(text='FOO')
4249

43-
assert await TestModel.aio_get_or_none(data="BAR") is not None
50+
assert await TestModel.aio_get_or_none(text="FOO") is not None
4451
assert db.aio_pool.has_acquired_connections() is False
4552

4653

4754
@dbs_all
4855
async def test_transaction_rollback(db):
4956
await TestModel.aio_create(text='FOO', data="")
5057

51-
try:
58+
with pytest.raises(IntegrityError):
5259
async with db.aio_atomic():
5360
await TestModel.update(data="BAR").aio_execute()
5461
assert await TestModel.aio_get_or_none(data="BAR") is not None
5562
await TestModel.aio_create(text='FOO')
56-
except:
57-
pass
5863

5964
assert await TestModel.aio_get_or_none(data="BAR") is None
6065
assert db.aio_pool.has_acquired_connections() is False
@@ -72,11 +77,9 @@ async def t1():
7277
async def t2():
7378
async with db.aio_atomic():
7479
await TestModel.aio_create(text='FOO2', data="")
75-
try:
80+
with pytest.raises(IntegrityError):
7681
async with db.aio_atomic():
7782
await TestModel.aio_create(text='FOO2', data="not_created")
78-
except:
79-
pass
8083

8184
async def t3():
8285
async with db.aio_atomic():
@@ -110,6 +113,33 @@ async def test_transaction_manual_work(db):
110113
assert db.aio_pool.has_acquired_connections() is False
111114

112115

116+
@dbs_all
117+
async def test_savepoint_success(db):
118+
async with db.aio_atomic():
119+
await TestModel.aio_create(text='FOO')
120+
121+
async with db.aio_atomic():
122+
await TestModel.update(text="BAR").aio_execute()
123+
124+
assert await TestModel.aio_get_or_none(text="BAR") is not None
125+
assert db.aio_pool.has_acquired_connections() is False
126+
127+
128+
@dbs_all
129+
async def test_savepoint_rollback(db):
130+
await TestModel.aio_create(text='FOO', data="")
131+
132+
async with db.aio_atomic():
133+
await TestModel.update(data="BAR").aio_execute()
134+
135+
with pytest.raises(IntegrityError):
136+
async with db.aio_atomic():
137+
await TestModel.aio_create(text='FOO')
138+
139+
assert await TestModel.aio_get_or_none(data="BAR") is not None
140+
assert db.aio_pool.has_acquired_connections() is False
141+
142+
113143
@dbs_all
114144
async def test_savepoint_manual_work(db):
115145
async with db.aio_connection() as connection:
@@ -178,3 +208,4 @@ async def insert_records(event_for_wait: asyncio.Event):
178208
# The transaction has not been committed
179209
assert len(list(await TestModel.select().aio_execute())) == 0
180210
assert db.aio_pool.has_acquired_connections() is False
211+

0 commit comments

Comments
 (0)