diff --git a/aiopg/sa/connection.py b/aiopg/sa/connection.py index 7897e89b..2f5ab88f 100644 --- a/aiopg/sa/connection.py +++ b/aiopg/sa/connection.py @@ -192,7 +192,6 @@ def _rollback_impl(self): cur.close() self._transaction = None - @asyncio.coroutine def begin_nested(self): """Begin a nested transaction and return a transaction handle. @@ -204,6 +203,11 @@ def begin_nested(self): still controls the overall .commit() or .rollback() of the transaction of a whole. """ + coro = self._begin_nested() + return _TransactionContextManager(coro) + + @asyncio.coroutine + def _begin_nested(self): if self._transaction is None: self._transaction = RootTransaction(self) yield from self._begin_impl() diff --git a/tests/pep492/test_async_await.py b/tests/pep492/test_async_await.py index c61ec0b7..13514894 100644 --- a/tests/pep492/test_async_await.py +++ b/tests/pep492/test_async_await.py @@ -214,6 +214,36 @@ async def test_transaction_context_manager_commit_once(pg_params, loop): assert conn.closed +@asyncio.coroutine +async def test_transaction_context_manager_nested_commit(pg_params, loop): + sql = 'SELECT generate_series(1, 5);' + result = [] + async with aiopg.sa.create_engine(loop=loop, **pg_params) as engine: + async with engine.acquire() as conn: + async with conn.begin_nested() as tr1: + async with conn.begin_nested() as tr2: + async with conn.execute(sql) as cursor: + async for v in cursor: + result.append(v) + assert tr1.is_active + assert tr2.is_active + assert result == [(1,), (2, ), (3, ), (4, ), (5, )] + assert cursor.closed + assert not tr2.is_active + + tr2 = await conn.begin_nested() + async with tr2: + assert tr2.is_active + async with conn.execute('SELECT 1;') as cursor: + rec = await cursor.scalar() + assert rec == 1 + cursor.close() + assert not tr2.is_active + assert not tr1.is_active + + assert conn.closed + + @asyncio.coroutine async def test_sa_connection_execute(pg_params, loop): sql = 'SELECT generate_series(1, 5);'