From baf5ce75bd0223189ea28fe2fa3d033097de9aaf Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 30 Mar 2017 13:43:30 -0400 Subject: [PATCH] Invalidate statement cache on schema changes affecting statement result. PostgreSQL will raise an exception when it detects that the result type of the query has changed from when the statement was prepared. This may happen, for example, after an ALTER TABLE or SET search_path. When this happens, and there is no transaction running, we can simply re-prepare the statement and try again. If the transaction _is_ running, this error will put it into an error state, and we have no choice but to raise an exception. The original error is somewhat cryptic, so we raise a custom InvalidCachedStatementError with the original server exception as context. In either case we clear the statement cache for this connection and all other connections of the pool this connection belongs to (if any). See #72 and #76 for discussion. Fixes: #72. --- asyncpg/connection.py | 71 +++++++++++++++++++++----- asyncpg/exceptions/__init__.py | 30 ++++++----- asyncpg/exceptions/_base.py | 48 +++++++++++++++--- asyncpg/pool.py | 6 +++ asyncpg/prepared_stmt.py | 24 ++++----- asyncpg/protocol/protocol.pyx | 4 +- tests/test_cache_invalidation.py | 85 ++++++++++++++++++++++++++++++++ tests/test_prepare.py | 18 +++++++ tools/generate_exceptions.py | 48 ++++++++++++++---- 9 files changed, 274 insertions(+), 60 deletions(-) create mode 100644 tests/test_cache_invalidation.py diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 97ab8998..3cc897d1 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -187,9 +187,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: if not args: return await self._protocol.query(query, timeout) - stmt = await self._get_statement(query, timeout) - _, status, _ = await self._protocol.bind_execute(stmt, args, '', 0, - True, timeout) + _, status, _ = await self._do_execute(query, args, 0, timeout, True) return status.decode() async def executemany(self, command: str, args, timeout: float=None): @@ -283,10 +281,7 @@ async def fetch(self, query, *args, timeout=None) -> list: :return list: A list of :class:`Record` instances. """ - stmt = await self._get_statement(query, timeout) - data = await self._protocol.bind_execute(stmt, args, '', 0, - False, timeout) - return data + return await self._do_execute(query, args, 0, timeout) async def fetchval(self, query, *args, column=0, timeout=None): """Run a query and return a value in the first row. @@ -302,9 +297,7 @@ async def fetchval(self, query, *args, column=0, timeout=None): :return: The value of the specified column of the first record. """ - stmt = await self._get_statement(query, timeout) - data = await self._protocol.bind_execute(stmt, args, '', 1, - False, timeout) + data = await self._do_execute(query, args, 1, timeout) if not data: return None return data[0][column] @@ -318,9 +311,7 @@ async def fetchrow(self, query, *args, timeout=None): :return: The first row as a :class:`Record` instance. """ - stmt = await self._get_statement(query, timeout) - data = await self._protocol.bind_execute(stmt, args, '', 1, - False, timeout) + data = await self._do_execute(query, args, 1, timeout) if not data: return None return data[0] @@ -551,6 +542,60 @@ def _set_proxy(self, proxy): self._proxy = proxy + def _drop_local_statement_cache(self): + self._stmt_cache.clear() + + def _drop_global_statement_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_statement_cache() + else: + self._drop_local_statement_cache() + + async def _do_execute(self, query, args, limit, timeout, + return_status=False): + stmt = await self._get_statement(query, timeout) + + try: + result = await self._protocol.bind_execute( + stmt, args, '', limit, return_status, timeout) + + except exceptions.InvalidCachedStatementError as e: + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. This may happen, + # for example, after an ALTER TABLE or SET search_path. + # + # When this happens, and there is no transaction running, + # we can simply re-prepare the statement and try once + # again. We deliberately retry only once as this is + # supposed to be a rare occurrence. + # + # If the transaction _is_ running, this error will put it + # into an error state, and we have no choice but to + # re-raise the exception. + # + # In either case we clear the statement cache for this + # connection and all other connections of the pool this + # connection belongs to (if any). + # + # See https://github.com/MagicStack/asyncpg/issues/72 + # and https://github.com/MagicStack/asyncpg/issues/76 + # for discussion. + # + self._drop_global_statement_cache() + + if self._protocol.is_in_transaction(): + raise + else: + stmt = await self._get_statement(query, timeout) + result = await self._protocol.bind_execute( + stmt, args, '', limit, return_status, timeout) + + return result + async def connect(dsn=None, *, host=None, port=None, diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py index 07d54139..30bca3f4 100644 --- a/asyncpg/exceptions/__init__.py +++ b/asyncpg/exceptions/__init__.py @@ -1,10 +1,3 @@ -# Copyright (C) 2016-present the ayncpg authors and contributors -# -# -# This module is part of asyncpg and is released under -# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 - - # GENERATED FROM postgresql/src/backend/utils/errcodes.txt # DO NOT MODIFY, use tools/generate_exceptions.py to update @@ -92,6 +85,10 @@ class FeatureNotSupportedError(_base.PostgresError): sqlstate = '0A000' +class InvalidCachedStatementError(FeatureNotSupportedError): + pass + + class InvalidTransactionInitiationError(_base.PostgresError): sqlstate = '0B000' @@ -1025,15 +1022,16 @@ class IndexCorruptedError(InternalServerError): 'InvalidArgumentForPowerFunctionError', 'InvalidArgumentForWidthBucketFunctionError', 'InvalidAuthorizationSpecificationError', - 'InvalidBinaryRepresentationError', 'InvalidCatalogNameError', - 'InvalidCharacterValueForCastError', 'InvalidColumnDefinitionError', - 'InvalidColumnReferenceError', 'InvalidCursorDefinitionError', - 'InvalidCursorNameError', 'InvalidCursorStateError', - 'InvalidDatabaseDefinitionError', 'InvalidDatetimeFormatError', - 'InvalidEscapeCharacterError', 'InvalidEscapeOctetError', - 'InvalidEscapeSequenceError', 'InvalidForeignKeyError', - 'InvalidFunctionDefinitionError', 'InvalidGrantOperationError', - 'InvalidGrantorError', 'InvalidIndicatorParameterValueError', + 'InvalidBinaryRepresentationError', 'InvalidCachedStatementError', + 'InvalidCatalogNameError', 'InvalidCharacterValueForCastError', + 'InvalidColumnDefinitionError', 'InvalidColumnReferenceError', + 'InvalidCursorDefinitionError', 'InvalidCursorNameError', + 'InvalidCursorStateError', 'InvalidDatabaseDefinitionError', + 'InvalidDatetimeFormatError', 'InvalidEscapeCharacterError', + 'InvalidEscapeOctetError', 'InvalidEscapeSequenceError', + 'InvalidForeignKeyError', 'InvalidFunctionDefinitionError', + 'InvalidGrantOperationError', 'InvalidGrantorError', + 'InvalidIndicatorParameterValueError', 'InvalidLocatorSpecificationError', 'InvalidNameError', 'InvalidObjectDefinitionError', 'InvalidParameterValueError', 'InvalidPasswordError', 'InvalidPreparedStatementDefinitionError', diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 89ac6f99..295671db 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -12,6 +12,11 @@ 'InterfaceError') +def _is_asyncpg_class(cls): + modname = cls.__module__ + return modname == 'asyncpg' or modname.startswith('asyncpg.') + + class PostgresMessageMeta(type): _message_map = {} _field_map = { @@ -40,8 +45,7 @@ def __new__(mcls, name, bases, dct): for f in mcls._field_map.values(): setattr(cls, f, None) - if (cls.__module__ == 'asyncpg' or - cls.__module__.startswith('asyncpg.')): + if _is_asyncpg_class(cls): mod = sys.modules[cls.__module__] if hasattr(mod, name): raise RuntimeError('exception class redefinition: {}'.format( @@ -74,21 +78,51 @@ def __str__(self): return msg @classmethod - def new(cls, fields, query=None): + def _get_error_template(cls, fields, query): errcode = fields.get('C') mcls = cls.__class__ exccls = mcls.get_message_class_for_sqlstate(errcode) - mapped = { + dct = { 'query': query } for k, v in fields.items(): field = mcls._field_map.get(k) if field: - mapped[field] = v + dct[field] = v - e = exccls(mapped.get('message', '')) - e.__dict__.update(mapped) + return exccls, dct + + @classmethod + def new(cls, fields, query=None): + exccls, dct = cls._get_error_template(fields, query) + + message = dct.get('message', '') + + # PostgreSQL will raise an exception when it detects + # that the result type of the query has changed from + # when the statement was prepared. + # + # The original error is somewhat cryptic and unspecific, + # so we raise a custom subclass that is easier to handle + # and identify. + # + # Note that we specifically do not rely on the error + # message, as it is localizable. + is_icse = ( + exccls.__name__ == 'FeatureNotSupportedError' and + _is_asyncpg_class(exccls) and + dct.get('server_source_function') == 'RevalidateCachedQuery' + ) + + if is_icse: + exceptions = sys.modules[exccls.__module__] + exccls = exceptions.InvalidCachedStatementError + message = ('cached statement plan is invalid due to a database ' + 'schema or configuration change') + + e = exccls(message) + e.__dict__.update(dct) return e diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 3cfd859e..edc3c08f 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -402,6 +402,12 @@ def _check_init(self): if self._closed: raise exceptions.InterfaceError('pool is closed') + def _drop_statement_cache(self): + # Drop statement cache for all connections in the pool. + for ch in self._holders: + if ch._con is not None: + ch._con._drop_local_statement_cache() + def __await__(self): return self._async__init__().__await__() diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index fa35c9d0..4e75ebc7 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -154,11 +154,7 @@ async def fetch(self, *args, timeout=None): :return: A list of :class:`Record` instances. """ - self.__check_open() - protocol = self._connection._protocol - data, status, _ = await protocol.bind_execute( - self._state, args, '', 0, True, timeout) - self._last_status = status + data = await self.__bind_execute(args, 0, timeout) return data async def fetchval(self, *args, column=0, timeout=None): @@ -174,11 +170,7 @@ async def fetchval(self, *args, column=0, timeout=None): :return: The value of the specified column of the first record. """ - self.__check_open() - protocol = self._connection._protocol - data, status, _ = await protocol.bind_execute( - self._state, args, '', 1, True, timeout) - self._last_status = status + data = await self.__bind_execute(args, 1, timeout) if not data: return None return data[0][column] @@ -192,14 +184,18 @@ async def fetchrow(self, *args, timeout=None): :return: The first row as a :class:`Record` instance. """ + data = await self.__bind_execute(args, 1, timeout) + if not data: + return None + return data[0] + + async def __bind_execute(self, args, limit, timeout): self.__check_open() protocol = self._connection._protocol data, status, _ = await protocol.bind_execute( - self._state, args, '', 1, True, timeout) + self._state, args, '', limit, True, timeout) self._last_status = status - if not data: - return None - return data[0] + return data def __check_open(self): if self._state.closed: diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 59a3d387..87ab8507 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -121,7 +121,9 @@ cdef class BaseProtocol(CoreProtocol): return self.settings def is_in_transaction(self): - return self.xact_status == PQTRANS_INTRANS + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) async def prepare(self, stmt_name, query, timeout): if self.cancel_waiter is not None: diff --git a/tests/test_cache_invalidation.py b/tests/test_cache_invalidation.py new file mode 100644 index 00000000..24e23eef --- /dev/null +++ b/tests/test_cache_invalidation.py @@ -0,0 +1,85 @@ +# Copyright (C) 2016-present the ayncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncpg +from asyncpg import _testbase as tb + + +class TestCacheInvalidation(tb.ConnectedTestCase): + async def test_prepare_cache_invalidation_silent(self): + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + finally: + await self.con.execute('DROP TABLE tab1') + + async def test_prepare_cache_invalidation_in_transaction(self): + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError, + 'cached statement plan is invalid'): + async with self.con.transaction(): + result = await self.con.fetchrow('SELECT * FROM tab1') + + # This is now OK, + result = await self.con.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + finally: + await self.con.execute('DROP TABLE tab1') + + async def test_prepare_cache_invalidation_in_pool(self): + pool = await self.create_pool(database='postgres', + min_size=2, max_size=2) + + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + + con1 = await pool.acquire() + con2 = await pool.acquire() + + result = await con1.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + result = await con2.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, 2)) + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + # con1 tries the same plan, will invalidate the cache + # for the entire pool. + result = await con1.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + + async with con2.transaction(): + # This should work, as con1 should have invalidated + # the plan cache. + result = await con2.fetchrow('SELECT * FROM tab1') + self.assertEqual(result, (1, '2')) + + finally: + await self.con.execute('DROP TABLE tab1') + await pool.close() diff --git a/tests/test_prepare.py b/tests/test_prepare.py index d09187cb..804eb7c6 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -406,3 +406,21 @@ async def test_prepare_22_empty(self): result = await self.con.fetchrow('SELECT') self.assertEqual(result, ()) self.assertEqual(repr(result), '') + + async def test_prepare_statement_invalid(self): + await self.con.execute('CREATE TABLE tab1(a int, b int)') + + try: + await self.con.execute('INSERT INTO tab1 VALUES (1, 2)') + + stmt = await self.con.prepare('SELECT * FROM tab1') + + await self.con.execute( + 'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text') + + with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError, + 'cached statement plan is invalid'): + await stmt.fetchrow() + + finally: + await self.con.execute('DROP TABLE tab1') diff --git a/tools/generate_exceptions.py b/tools/generate_exceptions.py index 9543639f..5633a312 100755 --- a/tools/generate_exceptions.py +++ b/tools/generate_exceptions.py @@ -29,6 +29,13 @@ } +_subclassmap = { + # Special subclass of FeatureNotSupportedError + # raised by Postgres in RevalidateCachedQuery. + '0A000': ['InvalidCachedStatementError'] +} + + def _get_error_name(sqlstatename, msgtype, sqlstate): if sqlstate in _namemap: return _namemap[sqlstate] @@ -73,7 +80,7 @@ def main(): tpl = """\ class {clsname}({base}): - {docstring}sqlstate = '{sqlstate}'""" + {docstring}{sqlstate}""" new_section = True section_class = None @@ -85,6 +92,24 @@ class {clsname}({base}): classes = [] clsnames = set() + def _add_class(clsname, base, sqlstate, docstring): + if sqlstate: + sqlstate = "sqlstate = '{}'".format(sqlstate) + else: + sqlstate = '' + + txt = tpl.format(clsname=clsname, base=base, sqlstate=sqlstate, + docstring=docstring) + + if not sqlstate and not docstring: + txt += 'pass' + + if len(txt.splitlines()[0]) > 79: + txt = txt.replace('(', '(\n ', 1) + + classes.append(txt) + clsnames.add(clsname) + for line in errcodes.splitlines(): if not line.strip() or line.startswith('#'): continue @@ -108,8 +133,8 @@ class {clsname}({base}): continue if clsname in clsnames: - raise ValueError('dupliate exception class name: {}'.format( - clsname)) + raise ValueError( + 'duplicate exception class name: {}'.format(clsname)) if new_section: section_class = clsname @@ -134,14 +159,19 @@ class {clsname}({base}): else: docstring = '' - txt = tpl.format(clsname=clsname, base=base, sqlstate=sqlstate, - docstring=docstring) + _add_class(clsname=clsname, base=base, sqlstate=sqlstate, + docstring=docstring) - if len(txt.splitlines()[0]) > 79: - txt = txt.replace('(', '(\n ', 1) + subclasses = _subclassmap.get(sqlstate, []) + for subclass in subclasses: + existing = getattr(apg_exc, subclass, None) + if existing and existing.__doc__: + docstring = '"""{}"""\n\n '.format(existing.__doc__) + else: + docstring = '' - classes.append(txt) - clsnames.add(clsname) + _add_class(clsname=subclass, base=clsname, sqlstate=None, + docstring=docstring) buf += '\n\n\n'.join(classes)