diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index 97ab8998..3cc897d1 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -187,9 +187,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
         if not args:
             return await self._protocol.query(query, timeout)

-        stmt = await self._get_statement(query, timeout)
-        _, status, _ = await self._protocol.bind_execute(stmt, args, '', 0,
-                                                         True, timeout)
+        _, status, _ = await self._do_execute(query, args, 0, timeout, True)
         return status.decode()

     async def executemany(self, command: str, args, timeout: float=None):
@@ -283,10 +281,7 @@ async def fetch(self, query, *args, timeout=None) -> list:

         :return list: A list of :class:`Record` instances.
         """
-        stmt = await self._get_statement(query, timeout)
-        data = await self._protocol.bind_execute(stmt, args, '', 0,
-                                                 False, timeout)
-        return data
+        return await self._do_execute(query, args, 0, timeout)

     async def fetchval(self, query, *args, column=0, timeout=None):
         """Run a query and return a value in the first row.
@@ -302,9 +297,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):

         :return: The value of the specified column of the first record.
         """
-        stmt = await self._get_statement(query, timeout)
-        data = await self._protocol.bind_execute(stmt, args, '', 1,
-                                                 False, timeout)
+        data = await self._do_execute(query, args, 1, timeout)
         if not data:
             return None
         return data[0][column]
@@ -318,9 +311,7 @@ async def fetchrow(self, query, *args, timeout=None):

         :return: The first row as a :class:`Record` instance.
         """
-        stmt = await self._get_statement(query, timeout)
-        data = await self._protocol.bind_execute(stmt, args, '', 1,
-                                                 False, timeout)
+        data = await self._do_execute(query, args, 1, timeout)
         if not data:
             return None
         return data[0]
@@ -551,6 +542,60 @@ def _set_proxy(self, proxy):

         self._proxy = proxy

+    def _drop_local_statement_cache(self):
+        self._stmt_cache.clear()
+
+    def _drop_global_statement_cache(self):
+        if self._proxy is not None:
+            # This connection is a member of a pool, so we delegate
+            # the cache drop to the pool.
+            pool = self._proxy._holder._pool
+            pool._drop_statement_cache()
+        else:
+            self._drop_local_statement_cache()
+
+    async def _do_execute(self, query, args, limit, timeout,
+                          return_status=False):
+        stmt = await self._get_statement(query, timeout)
+
+        try:
+            result = await self._protocol.bind_execute(
+                stmt, args, '', limit, return_status, timeout)
+
+        except exceptions.InvalidCachedStatementError:
+            # PostgreSQL will raise an exception when it detects
+            # that the result type of the query has changed from
+            # when the statement was prepared.  This may happen,
+            # for example, after an ALTER TABLE or SET search_path.
+            #
+            # When this happens, and there is no transaction running,
+            # we can simply re-prepare the statement and try once
+            # again.  We deliberately retry only once as this is
+            # supposed to be a rare occurrence.
+            #
+            # If the transaction _is_ running, this error will put it
+            # into an error state, and we have no choice but to
+            # re-raise the exception.
+            #
+            # In either case we clear the statement cache for this
+            # connection and all other connections of the pool this
+            # connection belongs to (if any).
+            #
+            # See https://github.com/MagicStack/asyncpg/issues/72
+            # and https://github.com/MagicStack/asyncpg/issues/76
+            # for discussion.
+            #
+            self._drop_global_statement_cache()
+
+            if self._protocol.is_in_transaction():
+                raise
+            else:
+                stmt = await self._get_statement(query, timeout)
+                result = await self._protocol.bind_execute(
+                    stmt, args, '', limit, return_status, timeout)
+
+        return result
+


 async def connect(dsn=None, *, host=None, port=None,
diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py
index 07d54139..30bca3f4 100644
--- a/asyncpg/exceptions/__init__.py
+++ b/asyncpg/exceptions/__init__.py
@@ -1,10 +1,3 @@
-# Copyright (C) 2016-present the ayncpg authors and contributors
-#
-#
-# This module is part of asyncpg and is released under
-# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
-
-
 # GENERATED FROM postgresql/src/backend/utils/errcodes.txt
 # DO NOT MODIFY, use tools/generate_exceptions.py to update

@@ -92,6 +85,10 @@ class FeatureNotSupportedError(_base.PostgresError):
     sqlstate = '0A000'


+class InvalidCachedStatementError(FeatureNotSupportedError):
+    pass
+
+
 class InvalidTransactionInitiationError(_base.PostgresError):
     sqlstate = '0B000'

@@ -1025,15 +1022,16 @@ class IndexCorruptedError(InternalServerError):
     'InvalidArgumentForPowerFunctionError',
     'InvalidArgumentForWidthBucketFunctionError',
     'InvalidAuthorizationSpecificationError',
-    'InvalidBinaryRepresentationError', 'InvalidCatalogNameError',
-    'InvalidCharacterValueForCastError', 'InvalidColumnDefinitionError',
-    'InvalidColumnReferenceError', 'InvalidCursorDefinitionError',
-    'InvalidCursorNameError', 'InvalidCursorStateError',
-    'InvalidDatabaseDefinitionError', 'InvalidDatetimeFormatError',
-    'InvalidEscapeCharacterError', 'InvalidEscapeOctetError',
-    'InvalidEscapeSequenceError', 'InvalidForeignKeyError',
-    'InvalidFunctionDefinitionError', 'InvalidGrantOperationError',
-    'InvalidGrantorError', 'InvalidIndicatorParameterValueError',
+    'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
+    'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
+    'InvalidColumnDefinitionError', 'InvalidColumnReferenceError',
+    'InvalidCursorDefinitionError', 'InvalidCursorNameError',
+    'InvalidCursorStateError', 'InvalidDatabaseDefinitionError',
+    'InvalidDatetimeFormatError', 'InvalidEscapeCharacterError',
+    'InvalidEscapeOctetError', 'InvalidEscapeSequenceError',
+    'InvalidForeignKeyError', 'InvalidFunctionDefinitionError',
+    'InvalidGrantOperationError', 'InvalidGrantorError',
+    'InvalidIndicatorParameterValueError',
     'InvalidLocatorSpecificationError', 'InvalidNameError',
     'InvalidObjectDefinitionError', 'InvalidParameterValueError',
     'InvalidPasswordError', 'InvalidPreparedStatementDefinitionError',
diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py
index 89ac6f99..295671db 100644
--- a/asyncpg/exceptions/_base.py
+++ b/asyncpg/exceptions/_base.py
@@ -12,6 +12,11 @@
            'InterfaceError')


+def _is_asyncpg_class(cls):
+    modname = cls.__module__
+    return modname == 'asyncpg' or modname.startswith('asyncpg.')
+
+
 class PostgresMessageMeta(type):
     _message_map = {}
     _field_map = {
@@ -40,8 +45,7 @@ def __new__(mcls, name, bases, dct):
         for f in mcls._field_map.values():
             setattr(cls, f, None)

-        if (cls.__module__ == 'asyncpg' or
-                cls.__module__.startswith('asyncpg.')):
+        if _is_asyncpg_class(cls):
             mod = sys.modules[cls.__module__]
             if hasattr(mod, name):
                 raise RuntimeError('exception class redefinition: {}'.format(
@@ -74,21 +78,51 @@ def __str__(self):
         return msg

     @classmethod
-    def new(cls, fields, query=None):
+    def _get_error_template(cls, fields, query):
         errcode = fields.get('C')
         mcls = cls.__class__
         exccls = mcls.get_message_class_for_sqlstate(errcode)

-        mapped = {
+        dct = {
             'query': query
         }

         for k, v in fields.items():
             field = mcls._field_map.get(k)
             if field:
-                mapped[field] = v
+                dct[field] = v

-        e = exccls(mapped.get('message', ''))
-        e.__dict__.update(mapped)
+        return exccls, dct
+
+    @classmethod
+    def new(cls, fields, query=None):
+        exccls, dct = cls._get_error_template(fields, query)
+
+        message = dct.get('message', '')
+
+        # PostgreSQL will raise an exception when it detects
+        # that the result type of the query has changed from
+        # when the statement was prepared.
+        #
+        # The original error is somewhat cryptic and unspecific,
+        # so we raise a custom subclass that is easier to handle
+        # and identify.
+        #
+        # Note that we specifically do not rely on the error
+        # message, as it is localizable.
+        is_icse = (
+            exccls.__name__ == 'FeatureNotSupportedError' and
+            _is_asyncpg_class(exccls) and
+            dct.get('server_source_function') == 'RevalidateCachedQuery'
+        )
+
+        if is_icse:
+            exceptions = sys.modules[exccls.__module__]
+            exccls = exceptions.InvalidCachedStatementError
+            message = ('cached statement plan is invalid due to a database '
+                       'schema or configuration change')
+
+        e = exccls(message)
+        e.__dict__.update(dct)

         return e
diff --git a/asyncpg/pool.py b/asyncpg/pool.py
index 3cfd859e..edc3c08f 100644
--- a/asyncpg/pool.py
+++ b/asyncpg/pool.py
@@ -402,6 +402,12 @@ def _check_init(self):
         if self._closed:
             raise exceptions.InterfaceError('pool is closed')

+    def _drop_statement_cache(self):
+        # Drop statement cache for all connections in the pool.
+        for ch in self._holders:
+            if ch._con is not None:
+                ch._con._drop_local_statement_cache()
+
     def __await__(self):
         return self._async__init__().__await__()

diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py
index fa35c9d0..4e75ebc7 100644
--- a/asyncpg/prepared_stmt.py
+++ b/asyncpg/prepared_stmt.py
@@ -154,11 +154,7 @@ async def fetch(self, *args, timeout=None):

         :return: A list of :class:`Record` instances.
         """
-        self.__check_open()
-        protocol = self._connection._protocol
-        data, status, _ = await protocol.bind_execute(
-            self._state, args, '', 0, True, timeout)
-        self._last_status = status
+        data = await self.__bind_execute(args, 0, timeout)
         return data

     async def fetchval(self, *args, column=0, timeout=None):
@@ -174,11 +170,7 @@ async def fetchval(self, *args, column=0, timeout=None):

         :return: The value of the specified column of the first record.
         """
-        self.__check_open()
-        protocol = self._connection._protocol
-        data, status, _ = await protocol.bind_execute(
-            self._state, args, '', 1, True, timeout)
-        self._last_status = status
+        data = await self.__bind_execute(args, 1, timeout)
         if not data:
             return None
         return data[0][column]
@@ -192,14 +184,18 @@ async def fetchrow(self, *args, timeout=None):

         :return: The first row as a :class:`Record` instance.
""" + data = await self.__bind_execute(args, 1, timeout) + if not data: + return None + return data[0] + + async def __bind_execute(self, args, limit, timeout): self.__check_open() protocol = self._connection._protocol data, status, _ = await protocol.bind_execute( - self._state, args, '', 1, True, timeout) + self._state, args, '', limit, True, timeout) self._last_status = status - if not data: - return None - return data[0] + return data def __check_open(self): if self._state.closed: diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 59a3d387..87ab8507 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -121,7 +121,9 @@ cdef class BaseProtocol(CoreProtocol): return self.settings def is_in_transaction(self): - return self.xact_status == PQTRANS_INTRANS + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) async def prepare(self, stmt_name, query, timeout): if self.cancel_waiter is not None: diff --git a/tests/test_cache_invalidation.py b/tests/test_cache_invalidation.py new file mode 100644 index 00000000..24e23eef --- /dev/null +++ b/tests/test_cache_invalidation.py @@ -0,0 +1,85 @@ +# Copyright (C) 2016-present the ayncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncpg +from asyncpg import _testbase as tb + + +class TestCacheInvalidation(tb.ConnectedTestCase): + async def test_prepare_cache_invalidation_silent(self): + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + finally: + await self.con.execute('DROP TABLE tab1') + + async def test_prepare_cache_invalidation_in_transaction(self): + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError, + 'cached statement plan is invalid'): + async with self.con.transaction(): + result = await self.con.fetchrow('SELECT * FROM tab1') + + # This is now OK, + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + finally: + await self.con.execute('DROP TABLE tab1') + + async def test_prepare_cache_invalidation_in_pool(self): + pool = await self.create_pool(database='postgres', + min_size=2, max_size=2) + + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + + con1 = await pool.acquire() + con2 = await pool.acquire() + + result = await con1.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + result = await con2.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + # con1 tries the same plan, will invalidate the cache + # for the entire pool. 
+            result = await con1.fetchrow('SELECT * FROM tab1')
+            self.assertEqual(result, (1, '2'))
+
+            async with con2.transaction():
+                # This should work, as con1 should have invalidated
+                # the plan cache.
+                result = await con2.fetchrow('SELECT * FROM tab1')
+                self.assertEqual(result, (1, '2'))
+
+        finally:
+            await self.con.execute('DROP TABLE tab1')
+            await pool.close()
diff --git a/tests/test_prepare.py b/tests/test_prepare.py
index d09187cb..804eb7c6 100644
--- a/tests/test_prepare.py
+++ b/tests/test_prepare.py
@@ -406,3 +406,21 @@ async def test_prepare_22_empty(self):
         result = await self.con.fetchrow('SELECT')
         self.assertEqual(result, ())
         self.assertEqual(repr(result), '<Record>')
+
+    async def test_prepare_statement_invalid(self):
+        await self.con.execute('CREATE TABLE tab1(a int, b int)')
+
+        try:
+            await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')
+
+            stmt = await self.con.prepare('SELECT * FROM tab1')
+
+            await self.con.execute(
+                'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')
+
+            with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError,
+                                        'cached statement plan is invalid'):
+                await stmt.fetchrow()
+
+        finally:
+            await self.con.execute('DROP TABLE tab1')
diff --git a/tools/generate_exceptions.py b/tools/generate_exceptions.py
index 9543639f..5633a312 100755
--- a/tools/generate_exceptions.py
+++ b/tools/generate_exceptions.py
@@ -29,6 +29,13 @@
 }


+_subclassmap = {
+    # Special subclass of FeatureNotSupportedError
+    # raised by Postgres in RevalidateCachedQuery.
+    '0A000': ['InvalidCachedStatementError']
+}
+
+
 def _get_error_name(sqlstatename, msgtype, sqlstate):
     if sqlstate in _namemap:
         return _namemap[sqlstate]
@@ -73,7 +80,7 @@ def main():

     tpl = """\
 class {clsname}({base}):
-    {docstring}sqlstate = '{sqlstate}'"""
+    {docstring}{sqlstate}"""

     new_section = True
     section_class = None
@@ -85,6 +92,24 @@ class {clsname}({base}):
     classes = []
     clsnames = set()

+    def _add_class(clsname, base, sqlstate, docstring):
+        if sqlstate:
+            sqlstate = "sqlstate = '{}'".format(sqlstate)
+        else:
+            sqlstate = ''
+
+        txt = tpl.format(clsname=clsname, base=base, sqlstate=sqlstate,
+                         docstring=docstring)
+
+        if not sqlstate and not docstring:
+            txt += 'pass'
+
+        if len(txt.splitlines()[0]) > 79:
+            txt = txt.replace('(', '(\n        ', 1)
+
+        classes.append(txt)
+        clsnames.add(clsname)
+
     for line in errcodes.splitlines():
         if not line.strip() or line.startswith('#'):
             continue
@@ -108,8 +133,8 @@ class {clsname}({base}):
             continue

         if clsname in clsnames:
-            raise ValueError('dupliate exception class name: {}'.format(
-                clsname))
+            raise ValueError(
+                'duplicate exception class name: {}'.format(clsname))

         if new_section:
             section_class = clsname
@@ -134,14 +159,19 @@ class {clsname}({base}):
         else:
             docstring = ''

-        txt = tpl.format(clsname=clsname, base=base, sqlstate=sqlstate,
-                         docstring=docstring)
+        _add_class(clsname=clsname, base=base, sqlstate=sqlstate,
+                   docstring=docstring)

-        if len(txt.splitlines()[0]) > 79:
-            txt = txt.replace('(', '(\n        ', 1)
+        subclasses = _subclassmap.get(sqlstate, [])
+        for subclass in subclasses:
+            existing = getattr(apg_exc, subclass, None)
+            if existing and existing.__doc__:
+                docstring = '"""{}"""\n\n    '.format(existing.__doc__)
+            else:
+                docstring = ''

-        classes.append(txt)
-        clsnames.add(clsname)
+            _add_class(clsname=subclass, base=clsname, sqlstate=None,
+                       docstring=docstring)

     buf += '\n\n\n'.join(classes)
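
Reviewer note: as the comment in Connection._do_execute above explains, the silent re-prepare is only possible outside a transaction; with a transaction open, InvalidCachedStatementError is re-raised and the application has to retry the whole transaction itself. A minimal sketch of that application-level handling, under the assumption of an already established connection; the helper name fetch_with_retry and the retry policy are illustrative and not part of this patch:

    import asyncpg

    async def fetch_with_retry(con, query, *args, attempts=2):
        # Retry the whole transaction if the cached statement plan was
        # invalidated by a concurrent schema change.  By the time this
        # error reaches the caller, asyncpg has already dropped the stale
        # cache entries, so a plain retry is sufficient.
        for attempt in range(attempts):
            try:
                async with con.transaction():
                    return await con.fetch(query, *args)
            except asyncpg.InvalidCachedStatementError:
                if attempt == attempts - 1:
                    raise

Outside an explicit transaction no such handling is needed: _do_execute re-prepares and retries once transparently, as the silent test case above exercises.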
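Because the new InvalidCachedStatementError is declared as a subclass of FeatureNotSupportedError (the class mapped to SQLSTATE 0A000), existing handlers that catch the broader class keep working. The sketch below only illustrates that relationship; the wrapper function and the use of tab1 (borrowed from the tests above) are hypothetical:

    import asyncpg

    async def fetch_tab1(con):
        try:
            return await con.fetch('SELECT * FROM tab1')
        except asyncpg.FeatureNotSupportedError as e:
            # The more specific error still lands here; code that wants to
            # treat cache invalidation differently can single it out.
            if isinstance(e, asyncpg.InvalidCachedStatementError):
                raise RuntimeError('statement cache was invalidated') from e
            raise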