Skip to content

Commit

Permalink
Guard transaction methods against underlying connection release
Browse files Browse the repository at this point in the history
Similarly to other connection-dependent objects, transaction methods
should not be called once the underlying connection has been released to
the pool.

Also, add a special handling for the case of asynchronous generator
finalization, in which case it's OK for `Transaction.__aexit__()` to be
called _after_ `Pool.release()`, since we cannot control when the
finalization task would execute.

Fixes: #232.
  • Loading branch information
elprans committed Dec 4, 2017
1 parent 46f468c commit 59e2878
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 13 deletions.
20 changes: 13 additions & 7 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,19 @@ async def reset(self, *, timeout=None):
self._listeners.clear()
self._log_listeners.clear()
reset_query = self._get_reset_query()

if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
# Managed transactions are guaranteed to __aexit__
# correctly.
self._loop.call_exception_handler({
'message': 'Resetting connection with an '
'active transaction {!r}'.format(self)
})

self._top_xact = None
reset_query = 'ROLLBACK;\n' + reset_query

if reset_query:
await self.execute(reset_query, timeout=timeout)

Expand Down Expand Up @@ -1152,13 +1165,6 @@ def _get_reset_query(self):
caps = self._server_caps

_reset_query = []
if self._protocol.is_in_transaction() or self._top_xact is not None:
self._loop.call_exception_handler({
'message': 'Resetting connection with an '
'active transaction {!r}'.format(self)
})
self._top_xact = None
_reset_query.append('ROLLBACK;')
if caps.advisory_locks:
_reset_query.append('SELECT pg_advisory_unlock_all();')
if caps.sql_close_all:
Expand Down
6 changes: 6 additions & 0 deletions asyncpg/connresource.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ def _check_conn_validity(self, meth_name):
'cannot call {}.{}(): '
'the underlying connection has been released back '
'to the pool'.format(self.__class__.__name__, meth_name))

if self._connection.is_closed():
raise exceptions.InterfaceError(
'cannot call {}.{}(): '
'the underlying connection is closed'.format(
self.__class__.__name__, meth_name))
8 changes: 4 additions & 4 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ def __getattr__(self, attr):
# Proxy all unresolved attributes to the wrapped Connection object.
return getattr(self._con, attr)

def _detach(self):
def _detach(self) -> connection.Connection:
if self._con is None:
raise exceptions.InterfaceError(
'cannot detach PoolConnectionProxy: already detached')

con, self._con = self._con, None
con._set_proxy(None)
return con

def __repr__(self):
if self._con is None:
Expand Down Expand Up @@ -179,8 +180,6 @@ async def release(self, timeout):
self._in_use = False
self._timeout = None

self._con._on_release()

if self._con.is_closed():
self._con = None

Expand Down Expand Up @@ -508,7 +507,8 @@ async def _release_impl(ch: PoolConnectionHolder, timeout: float):
# Already released, do nothing.
return

connection._detach()
con = connection._detach()
con._on_release()

if timeout is None:
timeout = connection._holder._timeout
Expand Down
25 changes: 23 additions & 2 deletions asyncpg/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import enum

from . import connresource
from . import exceptions as apg_errors


Expand All @@ -21,7 +22,7 @@ class TransactionState(enum.Enum):
ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'}


class Transaction:
class Transaction(connresource.ConnectionResource):
"""Represents a transaction or savepoint block.
Transactions are created by calling the
Expand All @@ -33,6 +34,8 @@ class Transaction:
'_state', '_nested', '_id', '_managed')

def __init__(self, connection, isolation, readonly, deferrable):
super().__init__(connection)

if isolation not in ISOLATION_LEVELS:
raise ValueError(
'isolation is expected to be either of {}, '
Expand All @@ -49,7 +52,6 @@ def __init__(self, connection, isolation, readonly, deferrable):
'"deferrable" is only supported for '
'serializable readonly transactions')

self._connection = connection
self._isolation = isolation
self._readonly = readonly
self._deferrable = deferrable
Expand All @@ -66,6 +68,22 @@ async def __aenter__(self):
await self.start()

async def __aexit__(self, extype, ex, tb):
try:
self._check_conn_validity('__aexit__')
except apg_errors.InterfaceError:
if extype is GeneratorExit:
# When a PoolAcquireContext is being exited, and there
# is an open transaction in an async generator that has
# not been iterated fully, there is a possibility that
# Pool.release() would race with this __aexit__(), since
# both would be in concurrent tasks. In such case we
# yield to Pool.release() to do the ROLLBACK for us.
# See https://github.com/MagicStack/asyncpg/issues/232
# for an example.
return
else:
raise

try:
if extype is not None:
await self.__rollback()
Expand All @@ -74,6 +92,7 @@ async def __aexit__(self, extype, ex, tb):
finally:
self._managed = False

@connresource.guarded
async def start(self):
"""Enter the transaction or savepoint block."""
self.__check_state_base('start')
Expand Down Expand Up @@ -183,13 +202,15 @@ async def __rollback(self):
else:
self._state = TransactionState.ROLLEDBACK

@connresource.guarded
async def commit(self):
"""Exit the transaction or savepoint block and commit changes."""
if self._managed:
raise apg_errors.InterfaceError(
'cannot manually commit from within an `async with` block')
await self.__commit()

@connresource.guarded
async def rollback(self):
"""Exit the transaction or savepoint block and rollback changes."""
if self._managed:
Expand Down
80 changes: 80 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import os
import platform
import random
import sys
import textwrap
import time
import unittest

Expand Down Expand Up @@ -195,6 +197,7 @@ async def test_pool_11(self):
self.assertIn(repr(con._con), repr(con)) # Test __repr__.

ps = await con.prepare('SELECT 1')
txn = con.transaction()
async with con.transaction():
cur = await con.cursor('SELECT 1')
ps_cur = await ps.cursor()
Expand Down Expand Up @@ -233,6 +236,14 @@ async def test_pool_11(self):

c.forward(1)

for meth in ('start', 'commit', 'rollback'):
with self.assertRaisesRegex(
asyncpg.InterfaceError,
r'cannot call Transaction\.{meth}.*released '
r'back to the pool'.format(meth=meth)):

getattr(txn, meth)()

await pool.close()

async def test_pool_12(self):
Expand Down Expand Up @@ -661,6 +672,75 @@ async def test_pool_handles_inactive_connection_errors(self):
await con.close()
await pool.close()

@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
async def test_pool_handles_transaction_exit_in_asyncgen_1(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)

locals_ = {}
exec(textwrap.dedent('''\
async def iterate(con):
async with con.transaction():
for record in await con.fetch("SELECT 1"):
yield record
'''), globals(), locals_)
iterate = locals_['iterate']

class MyException(Exception):
pass

with self.assertRaises(MyException):
async with pool.acquire() as con:
async for _ in iterate(con): # noqa
raise MyException()

@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
async def test_pool_handles_transaction_exit_in_asyncgen_2(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)

locals_ = {}
exec(textwrap.dedent('''\
async def iterate(con):
async with con.transaction():
for record in await con.fetch("SELECT 1"):
yield record
'''), globals(), locals_)
iterate = locals_['iterate']

class MyException(Exception):
pass

with self.assertRaises(MyException):
async with pool.acquire() as con:
iterator = iterate(con)
async for _ in iterator: # noqa
raise MyException()

del iterator

@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
async def test_pool_handles_asyncgen_finalization(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)

locals_ = {}
exec(textwrap.dedent('''\
async def iterate(con):
for record in await con.fetch("SELECT 1"):
yield record
'''), globals(), locals_)
iterate = locals_['iterate']

class MyException(Exception):
pass

with self.assertRaises(MyException):
async with pool.acquire() as con:
async with con.transaction():
async for _ in iterate(con): # noqa
raise MyException()


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHotStandby(tb.ConnectedTestCase):
Expand Down

0 comments on commit 59e2878

Please sign in to comment.