From e9bdf404a64878d356c52e75f35968c52ac76a41 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 28 Jul 2022 13:25:54 -0400 Subject: [PATCH 1/4] Log most recent N statements to try to get a handle on why we croak on the history of the connection use when we die of 'cannot use Connection.transaction() in a manually started transaction' --- asyncpg/connection.py | 589 ++++++++++++++++++++++++----------------- asyncpg/transaction.py | 135 ++++++---- 2 files changed, 423 insertions(+), 301 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 7c3668ca..3effc2a1 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -33,7 +33,6 @@ class ConnectionMeta(type): - def __instancecheck__(cls, instance): mro = type(instance).__mro__ return Connection in mro or _ConnectionProxy in mro @@ -45,19 +44,42 @@ class Connection(metaclass=ConnectionMeta): Connections are created by calling :func:`~asyncpg.connection.connect`. """ - __slots__ = ('_protocol', '_transport', '_loop', - '_top_xact', '_aborted', - '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', - '_listeners', '_server_version', '_server_caps', - '_intro_query', '_reset_query', '_proxy', - '_stmt_exclusive_section', '_config', '_params', '_addr', - '_log_listeners', '_termination_listeners', '_cancellations', - '_source_traceback', '__weakref__') - - def __init__(self, protocol, transport, loop, - addr, - config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + __slots__ = ( + "_protocol", + "_transport", + "_loop", + "_top_xact", + "_aborted", + "_pool_release_ctr", + "_stmt_cache", + "_stmts_to_close", + "_listeners", + "_server_version", + "_server_caps", + "_intro_query", + "_reset_query", + "_proxy", + "_stmt_exclusive_section", + "_config", + "_params", + "_addr", + "_log_listeners", + "_termination_listeners", + "_cancellations", + "_source_traceback", + "__weakref__", + "_recent_statements", + ) + + def __init__( + self, + protocol, + transport, + loop, + addr, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters, + ): self._protocol = protocol self._transport = transport self._loop = loop @@ -75,9 +97,9 @@ def __init__(self, protocol, transport, loop, self._stmt_cache = _StatementCache( loop=loop, max_size=config.statement_cache_size, - on_remove=functools.partial( - _weak_maybe_gc_stmt, weakref.ref(self)), - max_lifetime=config.max_cached_statement_lifetime) + on_remove=functools.partial(_weak_maybe_gc_stmt, weakref.ref(self)), + max_lifetime=config.max_cached_statement_lifetime, + ) self._stmts_to_close = set() @@ -88,11 +110,9 @@ def __init__(self, protocol, transport, loop, settings = self._protocol.get_settings() ver_string = settings.server_version - self._server_version = \ - serverversion.split_server_version_string(ver_string) + self._server_version = serverversion.split_server_version_string(ver_string) - self._server_caps = _detect_server_capabilities( - self._server_version, settings) + self._server_caps = _detect_server_capabilities(self._server_version, settings) self._intro_query = introspection.INTRO_LOOKUP_TYPES_CRDB @@ -113,11 +133,16 @@ def __init__(self, protocol, transport, loop, else: self._source_traceback = None + # circular buffer of most recent executed statements to assist in + # debugging transaction state issues. 
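+        # A bounded deque silently evicts its oldest entry once maxlen is
+        # reached, so each append stays O(1) with fixed memory. A minimal
+        # sketch of the behaviour (a maxlen of 3 is used purely for
+        # illustration):
+        #
+        #     >>> import collections
+        #     >>> buf = collections.deque(maxlen=3)
+        #     >>> for q in ("BEGIN", "SELECT 1", "COMMIT", "SELECT 2"):
+        #     ...     buf.append(q)
+        #     >>> list(buf)
+        #     ['SELECT 1', 'COMMIT', 'SELECT 2']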
+ self._recent_statements = collections.deque(maxlen=20) + def __del__(self): if not self.is_closed() and self._protocol is not None: if self._source_traceback: msg = "unclosed connection {!r}; created at:\n {}".format( - self, self._source_traceback) + self, self._source_traceback + ) else: msg = ( "unclosed connection {!r}; run in asyncio debug " @@ -147,7 +172,7 @@ async def add_listener(self, channel, callback): """ self._check_open() if channel not in self._listeners: - await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) + await self.fetch("LISTEN {}".format(utils._quote_ident(channel))) self._listeners[channel] = set() self._listeners[channel].add(_Callback.from_callable(callback)) @@ -163,7 +188,7 @@ async def remove_listener(self, channel, callback): self._listeners[channel].remove(cb) if not self._listeners[channel]: del self._listeners[channel] - await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) + await self.fetch("UNLISTEN {}".format(utils._quote_ident(channel))) def add_log_listener(self, callback): """Add a listener for Postgres log messages. @@ -184,7 +209,7 @@ def add_log_listener(self, callback): The ``callback`` argument may be a coroutine function. """ if self.is_closed(): - raise exceptions.InterfaceError('connection is closed') + raise exceptions.InterfaceError("connection is closed") self._log_listeners.add(_Callback.from_callable(callback)) def remove_log_listener(self, callback): @@ -246,8 +271,7 @@ def get_settings(self): """ return self._protocol.get_settings() - def transaction(self, *, isolation=None, readonly=False, - deferrable=False): + def transaction(self, *, isolation=None, readonly=False, deferrable=False): """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction @@ -281,7 +305,7 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute(self, query: str, *args, timeout: float = None) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -311,6 +335,10 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: """ self._check_open() + # Append to circular buffer of most recent executed statements + # for debugging. + self._recent_statements.append(query) + if not args: return await self._protocol.query(query, timeout) @@ -323,7 +351,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float = None): """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -362,7 +390,7 @@ async def _get_statement( named=False, use_cache=True, ignore_custom_codec=False, - record_class=None + record_class=None, ): if record_class is None: record_class = self._protocol.get_record_class() @@ -370,9 +398,7 @@ async def _get_statement( _check_record_class(record_class) if use_cache: - statement = self._stmt_cache.get( - (query, record_class, ignore_custom_codec) - ) + statement = self._stmt_cache.get((query, record_class, ignore_custom_codec)) if statement is not None: return statement @@ -380,17 +406,19 @@ async def _get_statement( # * `statement_cache_size` is greater than 0; # * query size is less than `max_cacheable_statement_size`. 
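        # (The cache key is the full triple used in the lookup above:
        #
        #     key = (query, record_class, ignore_custom_codec)
        #
        # so the same SQL text prepared with a different record class or
        # codec configuration occupies its own cache slot.)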
use_cache = self._stmt_cache.get_max_size() > 0 - if (use_cache and - self._config.max_cacheable_statement_size and - len(query) > self._config.max_cacheable_statement_size): + if ( + use_cache + and self._config.max_cacheable_statement_size + and len(query) > self._config.max_cacheable_statement_size + ): use_cache = False if isinstance(named, str): stmt_name = named elif use_cache or named: - stmt_name = self._get_unique_id('stmt') + stmt_name = self._get_unique_id("stmt") else: - stmt_name = '' + stmt_name = "" statement = await self._protocol.prepare( stmt_name, @@ -408,7 +436,8 @@ async def _get_statement( # Introspect newly seen types and populate the # codec cache. types, intro_stmt = await self._introspect_types( - types_with_missing_codecs, timeout) + types_with_missing_codecs, timeout + ) settings.register_data_types(types) @@ -424,8 +453,8 @@ async def _get_statement( # with reload_schema_state(), which would cause a # second try. More than five is clearly a bug. raise exceptions.InternalClientError( - 'could not resolve query result and/or argument types ' - 'in {} attempts'.format(tries) + "could not resolve query result and/or argument types " + "in {} attempts".format(tries) ) # Now that types have been resolved, populate the codec pipeline @@ -442,8 +471,7 @@ async def _get_statement( ) if use_cache: - self._stmt_cache.put( - (query, record_class, ignore_custom_codec), statement) + self._stmt_cache.put((query, record_class, ignore_custom_codec), statement) # If we've just created a new statement object, check if there # are any statements for GC. @@ -463,7 +491,7 @@ async def _introspect_types(self, typeoids, timeout): async def _introspect_type(self, typename, schema): if ( - schema == 'pg_catalog' + schema == "pg_catalog" and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP ): typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] @@ -484,19 +512,11 @@ async def _introspect_type(self, typename, schema): ) if not rows: - raise ValueError( - 'unknown type: {}.{}'.format(schema, typename)) + raise ValueError("unknown type: {}.{}".format(schema, typename)) return rows[0] - def cursor( - self, - query, - *args, - prefetch=None, - timeout=None, - record_class=None - ): + def cursor(self, query, *args, prefetch=None, timeout=None, record_class=None): """Return a *cursor factory* for the specified query. :param args: @@ -574,8 +594,8 @@ async def _prepare( *, name=None, timeout=None, - use_cache: bool=False, - record_class=None + use_cache: bool = False, + record_class=None, ): self._check_open() stmt = await self._get_statement( @@ -587,13 +607,7 @@ async def _prepare( ) return prepared_stmt.PreparedStatement(self, query, stmt) - async def fetch( - self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + async def fetch(self, query, *args, timeout=None, record_class=None) -> list: """Run a query and return the results as a list of :class:`Record`. :param str query: @@ -644,13 +658,7 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] - async def fetchrow( - self, - query, - *args, - timeout=None, - record_class=None - ): + async def fetchrow(self, query, *args, timeout=None, record_class=None): """Run a query and return the first row. 
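
        Example (illustrative; assumes an open connection ``con``)::

            row = await con.fetchrow("SELECT 1 AS x, 'a' AS y")
            assert row["x"] == 1 and row["y"] == 'a'
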
:param str query: @@ -684,11 +692,24 @@ async def fetchrow( return None return data[0] - async def copy_from_table(self, table_name, *, output, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, delimiter=None, - null=None, header=None, quote=None, - escape=None, force_quote=None, encoding=None): + async def copy_from_table( + self, + table_name, + *, + output, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None, + ): """Copy table contents to a file or file-like object. :param str table_name: @@ -737,30 +758,47 @@ async def copy_from_table(self, table_name, *, output, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." + tabname if columns: - cols = '({})'.format( - ', '.join(utils._quote_ident(c) for c in columns)) + cols = "({})".format(", ".join(utils._quote_ident(c) for c in columns)) else: - cols = '' + cols = "" opts = self._format_copy_opts( - format=format, oids=oids, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_quote=force_quote, encoding=encoding + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding, ) - copy_stmt = 'COPY {tab}{cols} TO STDOUT {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} TO STDOUT {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._copy_out(copy_stmt, output, timeout) - async def copy_from_query(self, query, *args, output, - timeout=None, format=None, oids=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - encoding=None): + async def copy_from_query( + self, + query, + *args, + output, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None, + ): """Copy the results of a query to a file or file-like object. :param str query: @@ -805,26 +843,45 @@ async def copy_from_query(self, query, *args, output, .. 
versionadded:: 0.11.0 """ opts = self._format_copy_opts( - format=format, oids=oids, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_quote=force_quote, encoding=encoding + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding, ) if args: query = await utils._mogrify(self, query, args) - copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( - query=query, opts=opts) + copy_stmt = "COPY ({query}) TO STDOUT {opts}".format(query=query, opts=opts) return await self._copy_out(copy_stmt, output, timeout) - async def copy_to_table(self, table_name, *, source, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - force_not_null=None, force_null=None, - encoding=None): + async def copy_to_table( + self, + table_name, + *, + source, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + freeze=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + force_not_null=None, + force_null=None, + encoding=None, + ): """Copy data to the specified table. :param str table_name: @@ -873,29 +930,36 @@ async def copy_to_table(self, table_name, *, source, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." + tabname if columns: - cols = '({})'.format( - ', '.join(utils._quote_ident(c) for c in columns)) + cols = "({})".format(", ".join(utils._quote_ident(c) for c in columns)) else: - cols = '' + cols = "" opts = self._format_copy_opts( - format=format, oids=oids, freeze=freeze, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_not_null=force_not_null, force_null=force_null, - encoding=encoding + format=format, + oids=oids, + freeze=freeze, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_not_null=force_not_null, + force_null=force_null, + encoding=encoding, ) - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} FROM STDIN {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._copy_in(copy_stmt, source, timeout) - async def copy_records_to_table(self, table_name, *, records, - columns=None, schema_name=None, - timeout=None): + async def copy_records_to_table( + self, table_name, *, records, columns=None, schema_name=None, timeout=None + ): """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -959,56 +1023,71 @@ async def copy_records_to_table(self, table_name, *, records, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." 
+ tabname if columns: - col_list = ', '.join(utils._quote_ident(c) for c in columns) - cols = '({})'.format(col_list) + col_list = ", ".join(utils._quote_ident(c) for c in columns) + cols = "({})".format(col_list) else: - col_list = '*' - cols = '' + col_list = "*" + cols = "" - intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( - tab=tabname, cols=col_list) + intro_query = "SELECT {cols} FROM {tab} LIMIT 1".format( + tab=tabname, cols=col_list + ) intro_ps = await self._prepare(intro_query, use_cache=True) - opts = '(FORMAT binary)' + opts = "(FORMAT binary)" - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} FROM STDIN {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._protocol.copy_in( - copy_stmt, None, None, records, intro_ps._state, timeout) + copy_stmt, None, None, records, intro_ps._state, timeout + ) - def _format_copy_opts(self, *, format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, quote=None, - escape=None, force_quote=None, force_not_null=None, - force_null=None, encoding=None): + def _format_copy_opts( + self, + *, + format=None, + oids=None, + freeze=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + force_not_null=None, + force_null=None, + encoding=None, + ): kwargs = dict(locals()) - kwargs.pop('self') + kwargs.pop("self") opts = [] if force_quote is not None and isinstance(force_quote, bool): - kwargs.pop('force_quote') + kwargs.pop("force_quote") if force_quote: - opts.append('FORCE_QUOTE *') + opts.append("FORCE_QUOTE *") for k, v in kwargs.items(): if v is not None: - if k in ('force_not_null', 'force_null', 'force_quote'): - v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')' - elif k in ('oids', 'freeze', 'header'): + if k in ("force_not_null", "force_null", "force_quote"): + v = "(" + ", ".join(utils._quote_ident(c) for c in v) + ")" + elif k in ("oids", "freeze", "header"): v = str(v) else: v = utils._quote_literal(v) - opts.append('{} {}'.format(k.upper(), v)) + opts.append("{} {}".format(k.upper(), v)) if opts: - return '(' + ', '.join(opts) + ')' + return "(" + ", ".join(opts) + ")" else: - return '' + return "" async def _copy_out(self, copy_stmt, output, timeout): try: @@ -1023,9 +1102,9 @@ async def _copy_out(self, copy_stmt, output, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'wb') + f = await run_in_executor(None, open, path, "wb") opened_by_us = True - elif hasattr(output, 'write'): + elif hasattr(output, "write"): # file-like f = output elif callable(output): @@ -1033,14 +1112,16 @@ async def _copy_out(self, copy_stmt, output, timeout): writer = output else: raise TypeError( - 'output is expected to be a file-like object, ' - 'a path-like object or a coroutine function, ' - 'not {}'.format(type(output).__name__) + "output is expected to be a file-like object, " + "a path-like object or a coroutine function, " + "not {}".format(type(output).__name__) ) if writer is None: + async def _writer(data): await run_in_executor(None, f.write, data) + writer = _writer try: @@ -1064,9 +1145,9 @@ async def _copy_in(self, copy_stmt, source, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'rb') + f = await run_in_executor(None, open, path, "rb") opened_by_us = True - elif hasattr(source, 'read'): + elif hasattr(source, "read"): # file-like f = source elif isinstance(source, 
collections.abc.AsyncIterable): @@ -1096,14 +1177,15 @@ async def __anext__(self): try: return await self._protocol.copy_in( - copy_stmt, reader, data, None, None, timeout) + copy_stmt, reader, data, None, None, timeout + ) finally: if opened_by_us: await run_in_executor(None, f.close) - async def set_type_codec(self, typename, *, - schema='public', encoder, decoder, - format='text'): + async def set_type_codec( + self, typename, *, schema="public", encoder, decoder, format="text" + ): """Set an encoder/decoder pair for the specified data type. :param typename: @@ -1219,27 +1301,29 @@ async def set_type_codec(self, typename, *, typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( - 'cannot use custom codec on non-scalar type {}.{}'.format( - schema, typename)) + "cannot use custom codec on non-scalar type {}.{}".format( + schema, typename + ) + ) if introspection.is_domain_type(typeinfo): raise exceptions.UnsupportedClientFeatureError( - 'custom codecs on domain types are not supported', - hint='Set the codec on the base type.', + "custom codecs on domain types are not supported", + hint="Set the codec on the base type.", detail=( - 'PostgreSQL does not distinguish domains from ' - 'their base types in query results at the protocol level.' - ) + "PostgreSQL does not distinguish domains from " + "their base types in query results at the protocol level." + ), ) - oid = typeinfo['oid'] + oid = typeinfo["oid"] self._protocol.get_settings().add_python_codec( - oid, typename, schema, 'scalar', - encoder, decoder, format) + oid, typename, schema, "scalar", encoder, decoder, format + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def reset_type_codec(self, typename, *, schema='public'): + async def reset_type_codec(self, typename, *, schema="public"): """Reset *typename* codec to the default implementation. :param typename: @@ -1254,14 +1338,15 @@ async def reset_type_codec(self, typename, *, schema='public'): typeinfo = await self._introspect_type(typename, schema) self._protocol.get_settings().remove_python_codec( - typeinfo['oid'], typename, schema) + typeinfo["oid"], typename, schema + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def set_builtin_type_codec(self, typename, *, - schema='public', codec_name, - format=None): + async def set_builtin_type_codec( + self, typename, *, schema="public", codec_name, format=None + ): """Set a builtin codec for the specified scalar data type. This method has two uses. The first is to register a builtin @@ -1298,13 +1383,14 @@ async def set_builtin_type_codec(self, typename, *, typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( - 'cannot alias non-scalar type {}.{}'.format( - schema, typename)) + "cannot alias non-scalar type {}.{}".format(schema, typename) + ) - oid = typeinfo['oid'] + oid = typeinfo["oid"] self._protocol.get_settings().set_builtin_type_codec( - oid, typename, schema, 'scalar', codec_name, format) + oid, typename, schema, "scalar", codec_name, format + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() @@ -1352,13 +1438,15 @@ async def reset(self, *, timeout=None): if self._top_xact is None or not self._top_xact._managed: # Managed transactions are guaranteed to __aexit__ # correctly. 
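            # loop.call_exception_handler() routes this report through the
            # handler an application may have installed with
            # loop.set_exception_handler(), so it can be captured in app
            # logs. A hedged sketch (the handler name is illustrative):
            #
            #     def on_loop_exception(loop, context):
            #         print(context["message"])
            #
            #     loop.set_exception_handler(on_loop_exception)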
- self._loop.call_exception_handler({ - 'message': 'Resetting connection with an ' - 'active transaction {!r}'.format(self) - }) + self._loop.call_exception_handler( + { + "message": "Resetting connection with an " + "active transaction {!r}".format(self) + } + ) self._top_xact = None - reset_query = 'ROLLBACK;\n' + reset_query + reset_query = "ROLLBACK;\n" + reset_query if reset_query: await self.execute(reset_query, timeout=timeout) @@ -1394,12 +1482,12 @@ def _clean_tasks(self): def _check_open(self): if self.is_closed(): - raise exceptions.InterfaceError('connection is closed') + raise exceptions.InterfaceError("connection is closed") def _get_unique_id(self, prefix): global _uid _uid += 1 - return '__asyncpg_{}_{:x}__'.format(prefix, _uid) + return "__asyncpg_{}_{:x}__".format(prefix, _uid) def _mark_stmts_as_closed(self): for stmt in self._stmt_cache.iter_statements(): @@ -1412,11 +1500,8 @@ def _mark_stmts_as_closed(self): self._stmts_to_close.clear() def _maybe_gc_stmt(self, stmt): - if ( - stmt.refs == 0 - and not self._stmt_cache.has( - (stmt.query, stmt.record_class, stmt.ignore_custom_codec) - ) + if stmt.refs == 0 and not self._stmt_cache.has( + (stmt.query, stmt.record_class, stmt.ignore_custom_codec) ): # If low-level `stmt` isn't referenced from any high-level # `PreparedStatement` object and is not in the `_stmt_cache`: @@ -1444,9 +1529,12 @@ async def _cancel(self, waiter): try: # Open new connection to the server await connect_utils._cancel( - loop=self._loop, addr=self._addr, params=self._params, + loop=self._loop, + addr=self._addr, + params=self._params, backend_pid=self._protocol.backend_pid, - backend_secret=self._protocol.backend_secret) + backend_secret=self._protocol.backend_secret, + ) except ConnectionResetError as ex: # On some systems Postgres will reset the connection # after processing the cancellation command. @@ -1464,8 +1552,7 @@ async def _cancel(self, waiter): if not waiter.done(): waiter.set_exception(ex) finally: - self._cancellations.discard( - compat.current_asyncio_task(self._loop)) + self._cancellations.discard(compat.current_asyncio_task(self._loop)) if not waiter.done(): waiter.set_result(None) @@ -1528,15 +1615,15 @@ def _get_reset_query(self): _reset_query = [] if caps.advisory_locks: - _reset_query.append('SELECT pg_advisory_unlock_all();') + _reset_query.append("SELECT pg_advisory_unlock_all();") if caps.sql_close_all: - _reset_query.append('CLOSE ALL;') + _reset_query.append("CLOSE ALL;") if caps.notifications and caps.plpgsql: - _reset_query.append('UNLISTEN *;') + _reset_query.append("UNLISTEN *;") if caps.sql_reset: - _reset_query.append('RESET ALL;') + _reset_query.append("RESET ALL;") - _reset_query = '\n'.join(_reset_query) + _reset_query = "\n".join(_reset_query) self._reset_query = _reset_query return _reset_query @@ -1545,7 +1632,8 @@ def _set_proxy(self, proxy): if self._proxy is not None and proxy is not None: # Should not happen unless there is a bug in `Pool`. 
raise exceptions.InterfaceError( - 'internal asyncpg error: connection is already proxied') + "internal asyncpg error: connection is already proxied" + ) self._proxy = proxy @@ -1554,10 +1642,11 @@ def _check_listeners(self, listeners, listener_type): count = len(listeners) w = exceptions.InterfaceWarning( - '{conn!r} is being released to the pool but has {c} active ' - '{type} listener{s}'.format( - conn=self, c=count, type=listener_type, - s='s' if count > 1 else '')) + "{conn!r} is being released to the pool but has {c} active " + "{type} listener{s}".format( + conn=self, c=count, type=listener_type, s="s" if count > 1 else "" + ) + ) warnings.warn(w) @@ -1568,9 +1657,9 @@ def _on_release(self, stacklevel=1): # Let's check that the user has not left any listeners on it. self._check_listeners( list(itertools.chain.from_iterable(self._listeners.values())), - 'notification') - self._check_listeners( - self._log_listeners, 'log') + "notification", + ) + self._check_listeners(self._log_listeners, "log") def _drop_local_statement_cache(self): self._stmt_cache.clear() @@ -1650,7 +1739,7 @@ async def _execute( *, return_status=False, ignore_custom_codec=False, - record_class=None + record_class=None, ): with self._stmt_exclusive_section: result, _ = await self.__execute( @@ -1673,10 +1762,11 @@ async def __execute( *, return_status=False, ignore_custom_codec=False, - record_class=None + record_class=None, ): executor = lambda stmt, timeout: self._protocol.bind_execute( - stmt, args, '', limit, return_status, timeout) + stmt, args, "", limit, return_status, timeout + ) timeout = self._protocol._get_timeout(timeout) return await self._do_execute( query, @@ -1688,7 +1778,8 @@ async def __execute( async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( - stmt, args, '', timeout) + stmt, args, "", timeout + ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: result, _ = await self._do_execute(query, executor, timeout) @@ -1702,7 +1793,7 @@ async def _do_execute( retry=True, *, ignore_custom_codec=False, - record_class=None + record_class=None, ): if timeout is None: stmt = await self._get_statement( @@ -1769,26 +1860,31 @@ async def _do_execute( if self._protocol.is_in_transaction() or not retry: raise else: - return await self._do_execute( - query, executor, timeout, retry=False) + return await self._do_execute(query, executor, timeout, retry=False) return result, stmt -async def connect(dsn=None, *, - host=None, port=None, - user=None, password=None, passfile=None, - database=None, - loop=None, - timeout=60, - statement_cache_size=100, - max_cached_statement_lifetime=300, - max_cacheable_statement_size=1024 * 15, - command_timeout=None, - ssl=None, - connection_class=Connection, - record_class=protocol.Record, - server_settings=None): +async def connect( + dsn=None, + *, + host=None, + port=None, + user=None, + password=None, + passfile=None, + database=None, + loop=None, + timeout=60, + statement_cache_size=100, + max_cached_statement_lifetime=300, + max_cacheable_statement_size=1024 * 15, + command_timeout=None, + ssl=None, + connection_class=Connection, + record_class=protocol.Record, + server_settings=None, +): r"""A coroutine to establish a connection to a PostgreSQL server. 
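
    A quick usage sketch (the DSN is illustrative)::

        con = await asyncpg.connect("postgresql://postgres@localhost/test")
        try:
            print(await con.fetchval("SELECT 2 + 2"))  # prints 4
        finally:
            await con.close()
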
The connection parameters may be specified either as a connection @@ -2070,8 +2166,9 @@ async def connect(dsn=None, *, """ if not issubclass(connection_class, Connection): raise exceptions.InterfaceError( - 'connection_class is expected to be a subclass of ' - 'asyncpg.Connection, got {!r}'.format(connection_class)) + "connection_class is expected to be a subclass of " + "asyncpg.Connection, got {!r}".format(connection_class) + ) if record_class is not protocol.Record: _check_record_class(record_class) @@ -2102,7 +2199,7 @@ async def connect(dsn=None, *, class _StatementCacheEntry: - __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') + __slots__ = ("_query", "_statement", "_cache", "_cleanup_cb") def __init__(self, cache, query, statement): self._cache = cache @@ -2113,8 +2210,7 @@ def __init__(self, cache, query, statement): class _StatementCache: - __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', - '_max_lifetime') + __slots__ = ("_loop", "_entries", "_max_size", "_on_remove", "_max_lifetime") def __init__(self, *, loop, max_size, on_remove, max_lifetime): self._loop = loop @@ -2223,7 +2319,8 @@ def _set_entry_timeout(self, entry): # Set the new timeout if it's not 0. if self._max_lifetime: entry._cleanup_cb = self._loop.call_later( - self._max_lifetime, self._on_entry_expired, entry) + self._max_lifetime, self._on_entry_expired, entry + ) def _new_entry(self, query, statement): entry = _StatementCacheEntry(self, query, statement) @@ -2258,22 +2355,21 @@ class _Callback(typing.NamedTuple): is_async: bool @classmethod - def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + def from_callable(cls, cb: typing.Callable[..., None]) -> "_Callback": if inspect.iscoroutinefunction(cb): is_async = True elif callable(cb): is_async = False else: raise exceptions.InterfaceError( - 'expected a callable or an `async def` function,' - 'got {!r}'.format(cb) + "expected a callable or an `async def` function," "got {!r}".format(cb) ) return cls(cb, is_async) class _Atomic: - __slots__ = ('_acquired',) + __slots__ = ("_acquired",) def __init__(self): self._acquired = 0 @@ -2281,7 +2377,8 @@ def __init__(self): def __enter__(self): if self._acquired: raise exceptions.InterfaceError( - 'cannot perform operation: another operation is in progress') + "cannot perform operation: another operation is in progress" + ) self._acquired = 1 def __exit__(self, t, e, tb): @@ -2294,28 +2391,28 @@ class _ConnectionProxy: ServerCapabilities = collections.namedtuple( - 'ServerCapabilities', - ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', - 'sql_close_all']) -ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + "ServerCapabilities", + ["advisory_locks", "notifications", "plpgsql", "sql_reset", "sql_close_all"], +) +ServerCapabilities.__doc__ = "PostgreSQL server capabilities." def _detect_server_capabilities(server_version, connection_settings): - if hasattr(connection_settings, 'padb_revision'): + if hasattr(connection_settings, "padb_revision"): # Amazon Redshift detected. advisory_locks = False notifications = False plpgsql = False sql_reset = True sql_close_all = False - elif hasattr(connection_settings, 'crdb_version'): + elif hasattr(connection_settings, "crdb_version"): # CockroachDB detected. advisory_locks = False notifications = False plpgsql = False sql_reset = False sql_close_all = False - elif hasattr(connection_settings, 'crate_version'): + elif hasattr(connection_settings, "crate_version"): # CrateDB detected. 
advisory_locks = False notifications = False @@ -2335,7 +2432,7 @@ def _detect_server_capabilities(server_version, connection_settings): notifications=notifications, plpgsql=plpgsql, sql_reset=sql_reset, - sql_close_all=sql_close_all + sql_close_all=sql_close_all, ) @@ -2346,7 +2443,8 @@ def _extract_stack(limit=10): frame = sys._getframe().f_back try: stack = traceback.StackSummary.extract( - traceback.walk_stack(frame), lookup_lines=False) + traceback.walk_stack(frame), lookup_lines=False + ) finally: del frame @@ -2354,30 +2452,27 @@ def _extract_stack(limit=10): i = 0 while i < len(stack) and stack[i][0].startswith(apg_path): i += 1 - stack = stack[i:i + limit] + stack = stack[i : i + limit] stack.reverse() - return ''.join(traceback.format_list(stack)) + return "".join(traceback.format_list(stack)) def _check_record_class(record_class): if record_class is protocol.Record: pass - elif ( - isinstance(record_class, type) - and issubclass(record_class, protocol.Record) - ): + elif isinstance(record_class, type) and issubclass(record_class, protocol.Record): if ( record_class.__new__ is not object.__new__ or record_class.__init__ is not object.__init__ ): raise exceptions.InterfaceError( - 'record_class must not redefine __new__ or __init__' + "record_class must not redefine __new__ or __init__" ) else: raise exceptions.InterfaceError( - 'record_class is expected to be a subclass of ' - 'asyncpg.Record, got {!r}'.format(record_class) + "record_class is expected to be a subclass of " + "asyncpg.Record, got {!r}".format(record_class) ) diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index 2d7ba49f..b7de6892 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -6,10 +6,14 @@ import enum +import structlog + from . import connresource from . import exceptions as apg_errors +logger = structlog.get_logger(__name__) + class TransactionState(enum.Enum): NEW = 0 @@ -19,11 +23,11 @@ class TransactionState(enum.Enum): FAILED = 4 -ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'} +ISOLATION_LEVELS = {"read_committed", "serializable", "repeatable_read"} ISOLATION_LEVELS_BY_VALUE = { - 'read committed': 'read_committed', - 'serializable': 'serializable', - 'repeatable read': 'repeatable_read', + "read committed": "read_committed", + "serializable": "serializable", + "repeatable read": "repeatable_read", } @@ -35,16 +39,25 @@ class Transaction(connresource.ConnectionResource): function. 
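
    Example (illustrative; assumes an open connection ``con``)::

        async with con.transaction():
            await con.execute("INSERT INTO mytab VALUES (1)")
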
""" - __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable', - '_state', '_nested', '_id', '_managed') + __slots__ = ( + "_connection", + "_isolation", + "_readonly", + "_deferrable", + "_state", + "_nested", + "_id", + "_managed", + ) def __init__(self, connection, isolation, readonly, deferrable): super().__init__(connection) if isolation and isolation not in ISOLATION_LEVELS: raise ValueError( - 'isolation is expected to be either of {}, ' - 'got {!r}'.format(ISOLATION_LEVELS, isolation)) + "isolation is expected to be either of {}, " + "got {!r}".format(ISOLATION_LEVELS, isolation) + ) self._isolation = isolation self._readonly = readonly @@ -57,13 +70,14 @@ def __init__(self, connection, isolation, readonly, deferrable): async def __aenter__(self): if self._managed: raise apg_errors.InterfaceError( - 'cannot enter context: already in an `async with` block') + "cannot enter context: already in an `async with` block" + ) self._managed = True await self.start() async def __aexit__(self, extype, ex, tb): try: - self._check_conn_validity('__aexit__') + self._check_conn_validity("__aexit__") except apg_errors.InterfaceError: if extype is GeneratorExit: # When a PoolAcquireContext is being exited, and there @@ -89,18 +103,26 @@ async def __aexit__(self, extype, ex, tb): @connresource.guarded async def start(self): """Enter the transaction or savepoint block.""" - self.__check_state_base('start') + self.__check_state_base("start") if self._state is TransactionState.STARTED: raise apg_errors.InterfaceError( - 'cannot start; the transaction is already started') + "cannot start; the transaction is already started" + ) con = self._connection if con._top_xact is None: if con._protocol.is_in_transaction(): + logger.error( + "bad transaction state for connection", + connection_id=id(con), + tx_state=con._protocol._protocol.xact_status, + recent_statements=list(con.recent_statements), + ) raise apg_errors.InterfaceError( - 'cannot use Connection.transaction() in ' - 'a manually started transaction') + "cannot use Connection.transaction() in " + "a manually started transaction" + ) con._top_xact = self else: # Nested transaction block @@ -108,31 +130,33 @@ async def start(self): top_xact_isolation = con._top_xact._isolation if top_xact_isolation is None: top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[ - await self._connection.fetchval( - 'SHOW transaction_isolation;')] + await self._connection.fetchval("SHOW transaction_isolation;") + ] if self._isolation != top_xact_isolation: raise apg_errors.InterfaceError( - 'nested transaction has a different isolation level: ' - 'current {!r} != outer {!r}'.format( - self._isolation, top_xact_isolation)) + "nested transaction has a different isolation level: " + "current {!r} != outer {!r}".format( + self._isolation, top_xact_isolation + ) + ) self._nested = True if self._nested: - self._id = con._get_unique_id('savepoint') - query = 'SAVEPOINT {};'.format(self._id) + self._id = con._get_unique_id("savepoint") + query = "SAVEPOINT {};".format(self._id) else: - query = 'BEGIN' - if self._isolation == 'read_committed': - query += ' ISOLATION LEVEL READ COMMITTED' - elif self._isolation == 'repeatable_read': - query += ' ISOLATION LEVEL REPEATABLE READ' - elif self._isolation == 'serializable': - query += ' ISOLATION LEVEL SERIALIZABLE' + query = "BEGIN" + if self._isolation == "read_committed": + query += " ISOLATION LEVEL READ COMMITTED" + elif self._isolation == "repeatable_read": + query += " ISOLATION LEVEL REPEATABLE READ" + elif 
self._isolation == "serializable": + query += " ISOLATION LEVEL SERIALIZABLE" if self._readonly: - query += ' READ ONLY' + query += " READ ONLY" if self._deferrable: - query += ' DEFERRABLE' - query += ';' + query += " DEFERRABLE" + query += ";" try: await self._connection.execute(query) @@ -145,35 +169,35 @@ async def start(self): def __check_state_base(self, opname): if self._state is TransactionState.COMMITTED: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is already committed'.format( - opname)) + "cannot {}; the transaction is already committed".format(opname) + ) if self._state is TransactionState.ROLLEDBACK: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is already rolled back'.format( - opname)) + "cannot {}; the transaction is already rolled back".format(opname) + ) if self._state is TransactionState.FAILED: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is in error state'.format( - opname)) + "cannot {}; the transaction is in error state".format(opname) + ) def __check_state(self, opname): if self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is not yet started'.format( - opname)) + "cannot {}; the transaction is not yet started".format(opname) + ) self.__check_state_base(opname) async def __commit(self): - self.__check_state('commit') + self.__check_state("commit") if self._connection._top_xact is self: self._connection._top_xact = None if self._nested: - query = 'RELEASE SAVEPOINT {};'.format(self._id) + query = "RELEASE SAVEPOINT {};".format(self._id) else: - query = 'COMMIT;' + query = "COMMIT;" try: await self._connection.execute(query) @@ -184,15 +208,15 @@ async def __commit(self): self._state = TransactionState.COMMITTED async def __rollback(self): - self.__check_state('rollback') + self.__check_state("rollback") if self._connection._top_xact is self: self._connection._top_xact = None if self._nested: - query = 'ROLLBACK TO {};'.format(self._id) + query = "ROLLBACK TO {};".format(self._id) else: - query = 'ROLLBACK;' + query = "ROLLBACK;" try: await self._connection.execute(query) @@ -207,7 +231,8 @@ async def commit(self): """Exit the transaction or savepoint block and commit changes.""" if self._managed: raise apg_errors.InterfaceError( - 'cannot manually commit from within an `async with` block') + "cannot manually commit from within an `async with` block" + ) await self.__commit() @connresource.guarded @@ -215,24 +240,26 @@ async def rollback(self): """Exit the transaction or savepoint block and rollback changes.""" if self._managed: raise apg_errors.InterfaceError( - 'cannot manually rollback from within an `async with` block') + "cannot manually rollback from within an `async with` block" + ) await self.__rollback() def __repr__(self): attrs = [] - attrs.append('state:{}'.format(self._state.name.lower())) + attrs.append("state:{}".format(self._state.name.lower())) if self._isolation is not None: attrs.append(self._isolation) if self._readonly: - attrs.append('readonly') + attrs.append("readonly") if self._deferrable: - attrs.append('deferrable') + attrs.append("deferrable") - if self.__class__.__module__.startswith('asyncpg.'): - mod = 'asyncpg' + if self.__class__.__module__.startswith("asyncpg."): + mod = "asyncpg" else: mod = self.__class__.__module__ - return '<{}.{} {} {:#x}>'.format( - mod, self.__class__.__name__, ' '.join(attrs), id(self)) + return "<{}.{} {} {:#x}>".format( + mod, self.__class__.__name__, " 
".join(attrs), id(self) + ) From 42de38df5c91ebbb961ad6321d9589d6b12d041e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 28 Jul 2022 13:25:54 -0400 Subject: [PATCH 2/4] Log most recent N statements to try to get a handle on why we croak on the history of the connection use when we die of 'cannot use Connection.transaction() in a manually started transaction' --- asyncpg/connection.py | 593 +++++++++++++++++++++++---------------- asyncpg/prepared_stmt.py | 27 +- asyncpg/transaction.py | 134 +++++---- 3 files changed, 441 insertions(+), 313 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 7c3668ca..8f3f58d6 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -33,7 +33,6 @@ class ConnectionMeta(type): - def __instancecheck__(cls, instance): mro = type(instance).__mro__ return Connection in mro or _ConnectionProxy in mro @@ -45,19 +44,42 @@ class Connection(metaclass=ConnectionMeta): Connections are created by calling :func:`~asyncpg.connection.connect`. """ - __slots__ = ('_protocol', '_transport', '_loop', - '_top_xact', '_aborted', - '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', - '_listeners', '_server_version', '_server_caps', - '_intro_query', '_reset_query', '_proxy', - '_stmt_exclusive_section', '_config', '_params', '_addr', - '_log_listeners', '_termination_listeners', '_cancellations', - '_source_traceback', '__weakref__') - - def __init__(self, protocol, transport, loop, - addr, - config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + __slots__ = ( + "_protocol", + "_transport", + "_loop", + "_top_xact", + "_aborted", + "_pool_release_ctr", + "_stmt_cache", + "_stmts_to_close", + "_listeners", + "_server_version", + "_server_caps", + "_intro_query", + "_reset_query", + "_proxy", + "_stmt_exclusive_section", + "_config", + "_params", + "_addr", + "_log_listeners", + "_termination_listeners", + "_cancellations", + "_source_traceback", + "__weakref__", + "_recent_statements", + ) + + def __init__( + self, + protocol, + transport, + loop, + addr, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters, + ): self._protocol = protocol self._transport = transport self._loop = loop @@ -75,9 +97,9 @@ def __init__(self, protocol, transport, loop, self._stmt_cache = _StatementCache( loop=loop, max_size=config.statement_cache_size, - on_remove=functools.partial( - _weak_maybe_gc_stmt, weakref.ref(self)), - max_lifetime=config.max_cached_statement_lifetime) + on_remove=functools.partial(_weak_maybe_gc_stmt, weakref.ref(self)), + max_lifetime=config.max_cached_statement_lifetime, + ) self._stmts_to_close = set() @@ -88,11 +110,9 @@ def __init__(self, protocol, transport, loop, settings = self._protocol.get_settings() ver_string = settings.server_version - self._server_version = \ - serverversion.split_server_version_string(ver_string) + self._server_version = serverversion.split_server_version_string(ver_string) - self._server_caps = _detect_server_capabilities( - self._server_version, settings) + self._server_caps = _detect_server_capabilities(self._server_version, settings) self._intro_query = introspection.INTRO_LOOKUP_TYPES_CRDB @@ -113,11 +133,16 @@ def __init__(self, protocol, transport, loop, else: self._source_traceback = None + # circular buffer of most recent executed statements to assist in + # debugging transaction state issues. 
+ self._recent_statements = collections.deque(maxlen=20) + def __del__(self): if not self.is_closed() and self._protocol is not None: if self._source_traceback: msg = "unclosed connection {!r}; created at:\n {}".format( - self, self._source_traceback) + self, self._source_traceback + ) else: msg = ( "unclosed connection {!r}; run in asyncio debug " @@ -147,7 +172,7 @@ async def add_listener(self, channel, callback): """ self._check_open() if channel not in self._listeners: - await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) + await self.fetch("LISTEN {}".format(utils._quote_ident(channel))) self._listeners[channel] = set() self._listeners[channel].add(_Callback.from_callable(callback)) @@ -163,7 +188,7 @@ async def remove_listener(self, channel, callback): self._listeners[channel].remove(cb) if not self._listeners[channel]: del self._listeners[channel] - await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) + await self.fetch("UNLISTEN {}".format(utils._quote_ident(channel))) def add_log_listener(self, callback): """Add a listener for Postgres log messages. @@ -184,7 +209,7 @@ def add_log_listener(self, callback): The ``callback`` argument may be a coroutine function. """ if self.is_closed(): - raise exceptions.InterfaceError('connection is closed') + raise exceptions.InterfaceError("connection is closed") self._log_listeners.add(_Callback.from_callable(callback)) def remove_log_listener(self, callback): @@ -246,8 +271,7 @@ def get_settings(self): """ return self._protocol.get_settings() - def transaction(self, *, isolation=None, readonly=False, - deferrable=False): + def transaction(self, *, isolation=None, readonly=False, deferrable=False): """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction @@ -281,7 +305,7 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute(self, query: str, *args, timeout: float = None) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -312,6 +336,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: self._check_open() if not args: + self._recent_statements.append(query) return await self._protocol.query(query, timeout) _, status, _ = await self._execute( @@ -323,7 +348,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: ) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float = None): """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -352,6 +377,7 @@ async def executemany(self, command: str, args, *, timeout: float=None): ``executemany()`` was called in a transaction. 
""" self._check_open() + return await self._executemany(command, args, timeout) async def _get_statement( @@ -362,7 +388,7 @@ async def _get_statement( named=False, use_cache=True, ignore_custom_codec=False, - record_class=None + record_class=None, ): if record_class is None: record_class = self._protocol.get_record_class() @@ -370,9 +396,7 @@ async def _get_statement( _check_record_class(record_class) if use_cache: - statement = self._stmt_cache.get( - (query, record_class, ignore_custom_codec) - ) + statement = self._stmt_cache.get((query, record_class, ignore_custom_codec)) if statement is not None: return statement @@ -380,17 +404,19 @@ async def _get_statement( # * `statement_cache_size` is greater than 0; # * query size is less than `max_cacheable_statement_size`. use_cache = self._stmt_cache.get_max_size() > 0 - if (use_cache and - self._config.max_cacheable_statement_size and - len(query) > self._config.max_cacheable_statement_size): + if ( + use_cache + and self._config.max_cacheable_statement_size + and len(query) > self._config.max_cacheable_statement_size + ): use_cache = False if isinstance(named, str): stmt_name = named elif use_cache or named: - stmt_name = self._get_unique_id('stmt') + stmt_name = self._get_unique_id("stmt") else: - stmt_name = '' + stmt_name = "" statement = await self._protocol.prepare( stmt_name, @@ -408,7 +434,8 @@ async def _get_statement( # Introspect newly seen types and populate the # codec cache. types, intro_stmt = await self._introspect_types( - types_with_missing_codecs, timeout) + types_with_missing_codecs, timeout + ) settings.register_data_types(types) @@ -424,8 +451,8 @@ async def _get_statement( # with reload_schema_state(), which would cause a # second try. More than five is clearly a bug. raise exceptions.InternalClientError( - 'could not resolve query result and/or argument types ' - 'in {} attempts'.format(tries) + "could not resolve query result and/or argument types " + "in {} attempts".format(tries) ) # Now that types have been resolved, populate the codec pipeline @@ -442,8 +469,7 @@ async def _get_statement( ) if use_cache: - self._stmt_cache.put( - (query, record_class, ignore_custom_codec), statement) + self._stmt_cache.put((query, record_class, ignore_custom_codec), statement) # If we've just created a new statement object, check if there # are any statements for GC. @@ -463,7 +489,7 @@ async def _introspect_types(self, typeoids, timeout): async def _introspect_type(self, typename, schema): if ( - schema == 'pg_catalog' + schema == "pg_catalog" and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP ): typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()] @@ -484,19 +510,11 @@ async def _introspect_type(self, typename, schema): ) if not rows: - raise ValueError( - 'unknown type: {}.{}'.format(schema, typename)) + raise ValueError("unknown type: {}.{}".format(schema, typename)) return rows[0] - def cursor( - self, - query, - *args, - prefetch=None, - timeout=None, - record_class=None - ): + def cursor(self, query, *args, prefetch=None, timeout=None, record_class=None): """Return a *cursor factory* for the specified query. :param args: @@ -518,6 +536,9 @@ def cursor( Added the *record_class* parameter. """ self._check_open() + + self._recent_statements.append(query) + return cursor.CursorFactory( self, query, @@ -560,6 +581,7 @@ async def prepare( .. versionchanged:: 0.25.0 Added the *name* parameter. 
""" + return await self._prepare( query, name=name, @@ -574,8 +596,8 @@ async def _prepare( *, name=None, timeout=None, - use_cache: bool=False, - record_class=None + use_cache: bool = False, + record_class=None, ): self._check_open() stmt = await self._get_statement( @@ -587,13 +609,7 @@ async def _prepare( ) return prepared_stmt.PreparedStatement(self, query, stmt) - async def fetch( - self, - query, - *args, - timeout=None, - record_class=None - ) -> list: + async def fetch(self, query, *args, timeout=None, record_class=None) -> list: """Run a query and return the results as a list of :class:`Record`. :param str query: @@ -615,6 +631,7 @@ async def fetch( Added the *record_class* parameter. """ self._check_open() + return await self._execute( query, args, @@ -644,13 +661,7 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] - async def fetchrow( - self, - query, - *args, - timeout=None, - record_class=None - ): + async def fetchrow(self, query, *args, timeout=None, record_class=None): """Run a query and return the first row. :param str query: @@ -684,11 +695,24 @@ async def fetchrow( return None return data[0] - async def copy_from_table(self, table_name, *, output, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, delimiter=None, - null=None, header=None, quote=None, - escape=None, force_quote=None, encoding=None): + async def copy_from_table( + self, + table_name, + *, + output, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None, + ): """Copy table contents to a file or file-like object. :param str table_name: @@ -737,30 +761,47 @@ async def copy_from_table(self, table_name, *, output, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." + tabname if columns: - cols = '({})'.format( - ', '.join(utils._quote_ident(c) for c in columns)) + cols = "({})".format(", ".join(utils._quote_ident(c) for c in columns)) else: - cols = '' + cols = "" opts = self._format_copy_opts( - format=format, oids=oids, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_quote=force_quote, encoding=encoding + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding, ) - copy_stmt = 'COPY {tab}{cols} TO STDOUT {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} TO STDOUT {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._copy_out(copy_stmt, output, timeout) - async def copy_from_query(self, query, *args, output, - timeout=None, format=None, oids=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - encoding=None): + async def copy_from_query( + self, + query, + *args, + output, + timeout=None, + format=None, + oids=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + encoding=None, + ): """Copy the results of a query to a file or file-like object. :param str query: @@ -805,26 +846,45 @@ async def copy_from_query(self, query, *args, output, .. 
versionadded:: 0.11.0 """ opts = self._format_copy_opts( - format=format, oids=oids, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_quote=force_quote, encoding=encoding + format=format, + oids=oids, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_quote=force_quote, + encoding=encoding, ) if args: query = await utils._mogrify(self, query, args) - copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( - query=query, opts=opts) + copy_stmt = "COPY ({query}) TO STDOUT {opts}".format(query=query, opts=opts) return await self._copy_out(copy_stmt, output, timeout) - async def copy_to_table(self, table_name, *, source, - columns=None, schema_name=None, timeout=None, - format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, - quote=None, escape=None, force_quote=None, - force_not_null=None, force_null=None, - encoding=None): + async def copy_to_table( + self, + table_name, + *, + source, + columns=None, + schema_name=None, + timeout=None, + format=None, + oids=None, + freeze=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + force_not_null=None, + force_null=None, + encoding=None, + ): """Copy data to the specified table. :param str table_name: @@ -873,29 +933,36 @@ async def copy_to_table(self, table_name, *, source, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." + tabname if columns: - cols = '({})'.format( - ', '.join(utils._quote_ident(c) for c in columns)) + cols = "({})".format(", ".join(utils._quote_ident(c) for c in columns)) else: - cols = '' + cols = "" opts = self._format_copy_opts( - format=format, oids=oids, freeze=freeze, delimiter=delimiter, - null=null, header=header, quote=quote, escape=escape, - force_not_null=force_not_null, force_null=force_null, - encoding=encoding + format=format, + oids=oids, + freeze=freeze, + delimiter=delimiter, + null=null, + header=header, + quote=quote, + escape=escape, + force_not_null=force_not_null, + force_null=force_null, + encoding=encoding, ) - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} FROM STDIN {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._copy_in(copy_stmt, source, timeout) - async def copy_records_to_table(self, table_name, *, records, - columns=None, schema_name=None, - timeout=None): + async def copy_records_to_table( + self, table_name, *, records, columns=None, schema_name=None, timeout=None + ): """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -959,56 +1026,71 @@ async def copy_records_to_table(self, table_name, *, records, """ tabname = utils._quote_ident(table_name) if schema_name: - tabname = utils._quote_ident(schema_name) + '.' + tabname + tabname = utils._quote_ident(schema_name) + "." 
+ tabname if columns: - col_list = ', '.join(utils._quote_ident(c) for c in columns) - cols = '({})'.format(col_list) + col_list = ", ".join(utils._quote_ident(c) for c in columns) + cols = "({})".format(col_list) else: - col_list = '*' - cols = '' + col_list = "*" + cols = "" - intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( - tab=tabname, cols=col_list) + intro_query = "SELECT {cols} FROM {tab} LIMIT 1".format( + tab=tabname, cols=col_list + ) intro_ps = await self._prepare(intro_query, use_cache=True) - opts = '(FORMAT binary)' + opts = "(FORMAT binary)" - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = "COPY {tab}{cols} FROM STDIN {opts}".format( + tab=tabname, cols=cols, opts=opts + ) return await self._protocol.copy_in( - copy_stmt, None, None, records, intro_ps._state, timeout) + copy_stmt, None, None, records, intro_ps._state, timeout + ) - def _format_copy_opts(self, *, format=None, oids=None, freeze=None, - delimiter=None, null=None, header=None, quote=None, - escape=None, force_quote=None, force_not_null=None, - force_null=None, encoding=None): + def _format_copy_opts( + self, + *, + format=None, + oids=None, + freeze=None, + delimiter=None, + null=None, + header=None, + quote=None, + escape=None, + force_quote=None, + force_not_null=None, + force_null=None, + encoding=None, + ): kwargs = dict(locals()) - kwargs.pop('self') + kwargs.pop("self") opts = [] if force_quote is not None and isinstance(force_quote, bool): - kwargs.pop('force_quote') + kwargs.pop("force_quote") if force_quote: - opts.append('FORCE_QUOTE *') + opts.append("FORCE_QUOTE *") for k, v in kwargs.items(): if v is not None: - if k in ('force_not_null', 'force_null', 'force_quote'): - v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')' - elif k in ('oids', 'freeze', 'header'): + if k in ("force_not_null", "force_null", "force_quote"): + v = "(" + ", ".join(utils._quote_ident(c) for c in v) + ")" + elif k in ("oids", "freeze", "header"): v = str(v) else: v = utils._quote_literal(v) - opts.append('{} {}'.format(k.upper(), v)) + opts.append("{} {}".format(k.upper(), v)) if opts: - return '(' + ', '.join(opts) + ')' + return "(" + ", ".join(opts) + ")" else: - return '' + return "" async def _copy_out(self, copy_stmt, output, timeout): try: @@ -1023,9 +1105,9 @@ async def _copy_out(self, copy_stmt, output, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'wb') + f = await run_in_executor(None, open, path, "wb") opened_by_us = True - elif hasattr(output, 'write'): + elif hasattr(output, "write"): # file-like f = output elif callable(output): @@ -1033,14 +1115,16 @@ async def _copy_out(self, copy_stmt, output, timeout): writer = output else: raise TypeError( - 'output is expected to be a file-like object, ' - 'a path-like object or a coroutine function, ' - 'not {}'.format(type(output).__name__) + "output is expected to be a file-like object, " + "a path-like object or a coroutine function, " + "not {}".format(type(output).__name__) ) if writer is None: + async def _writer(data): await run_in_executor(None, f.write, data) + writer = _writer try: @@ -1064,9 +1148,9 @@ async def _copy_in(self, copy_stmt, source, timeout): if path is not None: # a path - f = await run_in_executor(None, open, path, 'rb') + f = await run_in_executor(None, open, path, "rb") opened_by_us = True - elif hasattr(source, 'read'): + elif hasattr(source, "read"): # file-like f = source elif isinstance(source, 
collections.abc.AsyncIterable): @@ -1096,14 +1180,15 @@ async def __anext__(self): try: return await self._protocol.copy_in( - copy_stmt, reader, data, None, None, timeout) + copy_stmt, reader, data, None, None, timeout + ) finally: if opened_by_us: await run_in_executor(None, f.close) - async def set_type_codec(self, typename, *, - schema='public', encoder, decoder, - format='text'): + async def set_type_codec( + self, typename, *, schema="public", encoder, decoder, format="text" + ): """Set an encoder/decoder pair for the specified data type. :param typename: @@ -1219,27 +1304,29 @@ async def set_type_codec(self, typename, *, typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( - 'cannot use custom codec on non-scalar type {}.{}'.format( - schema, typename)) + "cannot use custom codec on non-scalar type {}.{}".format( + schema, typename + ) + ) if introspection.is_domain_type(typeinfo): raise exceptions.UnsupportedClientFeatureError( - 'custom codecs on domain types are not supported', - hint='Set the codec on the base type.', + "custom codecs on domain types are not supported", + hint="Set the codec on the base type.", detail=( - 'PostgreSQL does not distinguish domains from ' - 'their base types in query results at the protocol level.' - ) + "PostgreSQL does not distinguish domains from " + "their base types in query results at the protocol level." + ), ) - oid = typeinfo['oid'] + oid = typeinfo["oid"] self._protocol.get_settings().add_python_codec( - oid, typename, schema, 'scalar', - encoder, decoder, format) + oid, typename, schema, "scalar", encoder, decoder, format + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def reset_type_codec(self, typename, *, schema='public'): + async def reset_type_codec(self, typename, *, schema="public"): """Reset *typename* codec to the default implementation. :param typename: @@ -1254,14 +1341,15 @@ async def reset_type_codec(self, typename, *, schema='public'): typeinfo = await self._introspect_type(typename, schema) self._protocol.get_settings().remove_python_codec( - typeinfo['oid'], typename, schema) + typeinfo["oid"], typename, schema + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - async def set_builtin_type_codec(self, typename, *, - schema='public', codec_name, - format=None): + async def set_builtin_type_codec( + self, typename, *, schema="public", codec_name, format=None + ): """Set a builtin codec for the specified scalar data type. This method has two uses. The first is to register a builtin @@ -1298,13 +1386,14 @@ async def set_builtin_type_codec(self, typename, *, typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( - 'cannot alias non-scalar type {}.{}'.format( - schema, typename)) + "cannot alias non-scalar type {}.{}".format(schema, typename) + ) - oid = typeinfo['oid'] + oid = typeinfo["oid"] self._protocol.get_settings().set_builtin_type_codec( - oid, typename, schema, 'scalar', codec_name, format) + oid, typename, schema, "scalar", codec_name, format + ) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() @@ -1352,13 +1441,15 @@ async def reset(self, *, timeout=None): if self._top_xact is None or not self._top_xact._managed: # Managed transactions are guaranteed to __aexit__ # correctly. 
- self._loop.call_exception_handler({ - 'message': 'Resetting connection with an ' - 'active transaction {!r}'.format(self) - }) + self._loop.call_exception_handler( + { + "message": "Resetting connection with an " + "active transaction {!r}".format(self) + } + ) self._top_xact = None - reset_query = 'ROLLBACK;\n' + reset_query + reset_query = "ROLLBACK;\n" + reset_query if reset_query: await self.execute(reset_query, timeout=timeout) @@ -1394,12 +1485,12 @@ def _clean_tasks(self): def _check_open(self): if self.is_closed(): - raise exceptions.InterfaceError('connection is closed') + raise exceptions.InterfaceError("connection is closed") def _get_unique_id(self, prefix): global _uid _uid += 1 - return '__asyncpg_{}_{:x}__'.format(prefix, _uid) + return "__asyncpg_{}_{:x}__".format(prefix, _uid) def _mark_stmts_as_closed(self): for stmt in self._stmt_cache.iter_statements(): @@ -1412,11 +1503,8 @@ def _mark_stmts_as_closed(self): self._stmts_to_close.clear() def _maybe_gc_stmt(self, stmt): - if ( - stmt.refs == 0 - and not self._stmt_cache.has( - (stmt.query, stmt.record_class, stmt.ignore_custom_codec) - ) + if stmt.refs == 0 and not self._stmt_cache.has( + (stmt.query, stmt.record_class, stmt.ignore_custom_codec) ): # If low-level `stmt` isn't referenced from any high-level # `PreparedStatement` object and is not in the `_stmt_cache`: @@ -1444,9 +1532,12 @@ async def _cancel(self, waiter): try: # Open new connection to the server await connect_utils._cancel( - loop=self._loop, addr=self._addr, params=self._params, + loop=self._loop, + addr=self._addr, + params=self._params, backend_pid=self._protocol.backend_pid, - backend_secret=self._protocol.backend_secret) + backend_secret=self._protocol.backend_secret, + ) except ConnectionResetError as ex: # On some systems Postgres will reset the connection # after processing the cancellation command. @@ -1464,8 +1555,7 @@ async def _cancel(self, waiter): if not waiter.done(): waiter.set_exception(ex) finally: - self._cancellations.discard( - compat.current_asyncio_task(self._loop)) + self._cancellations.discard(compat.current_asyncio_task(self._loop)) if not waiter.done(): waiter.set_result(None) @@ -1528,15 +1618,15 @@ def _get_reset_query(self): _reset_query = [] if caps.advisory_locks: - _reset_query.append('SELECT pg_advisory_unlock_all();') + _reset_query.append("SELECT pg_advisory_unlock_all();") if caps.sql_close_all: - _reset_query.append('CLOSE ALL;') + _reset_query.append("CLOSE ALL;") if caps.notifications and caps.plpgsql: - _reset_query.append('UNLISTEN *;') + _reset_query.append("UNLISTEN *;") if caps.sql_reset: - _reset_query.append('RESET ALL;') + _reset_query.append("RESET ALL;") - _reset_query = '\n'.join(_reset_query) + _reset_query = "\n".join(_reset_query) self._reset_query = _reset_query return _reset_query @@ -1545,7 +1635,8 @@ def _set_proxy(self, proxy): if self._proxy is not None and proxy is not None: # Should not happen unless there is a bug in `Pool`. 
raise exceptions.InterfaceError( - 'internal asyncpg error: connection is already proxied') + "internal asyncpg error: connection is already proxied" + ) self._proxy = proxy @@ -1554,10 +1645,11 @@ def _check_listeners(self, listeners, listener_type): count = len(listeners) w = exceptions.InterfaceWarning( - '{conn!r} is being released to the pool but has {c} active ' - '{type} listener{s}'.format( - conn=self, c=count, type=listener_type, - s='s' if count > 1 else '')) + "{conn!r} is being released to the pool but has {c} active " + "{type} listener{s}".format( + conn=self, c=count, type=listener_type, s="s" if count > 1 else "" + ) + ) warnings.warn(w) @@ -1568,9 +1660,9 @@ def _on_release(self, stacklevel=1): # Let's check that the user has not left any listeners on it. self._check_listeners( list(itertools.chain.from_iterable(self._listeners.values())), - 'notification') - self._check_listeners( - self._log_listeners, 'log') + "notification", + ) + self._check_listeners(self._log_listeners, "log") def _drop_local_statement_cache(self): self._stmt_cache.clear() @@ -1650,7 +1742,7 @@ async def _execute( *, return_status=False, ignore_custom_codec=False, - record_class=None + record_class=None, ): with self._stmt_exclusive_section: result, _ = await self.__execute( @@ -1673,10 +1765,11 @@ async def __execute( *, return_status=False, ignore_custom_codec=False, - record_class=None + record_class=None, ): executor = lambda stmt, timeout: self._protocol.bind_execute( - stmt, args, '', limit, return_status, timeout) + stmt, args, "", limit, return_status, timeout + ) timeout = self._protocol._get_timeout(timeout) return await self._do_execute( query, @@ -1688,7 +1781,8 @@ async def __execute( async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( - stmt, args, '', timeout) + stmt, args, "", timeout + ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: result, _ = await self._do_execute(query, executor, timeout) @@ -1702,8 +1796,9 @@ async def _do_execute( retry=True, *, ignore_custom_codec=False, - record_class=None + record_class=None, ): + self._recent_statements.append(query) if timeout is None: stmt = await self._get_statement( query, @@ -1769,26 +1864,31 @@ async def _do_execute( if self._protocol.is_in_transaction() or not retry: raise else: - return await self._do_execute( - query, executor, timeout, retry=False) + return await self._do_execute(query, executor, timeout, retry=False) return result, stmt -async def connect(dsn=None, *, - host=None, port=None, - user=None, password=None, passfile=None, - database=None, - loop=None, - timeout=60, - statement_cache_size=100, - max_cached_statement_lifetime=300, - max_cacheable_statement_size=1024 * 15, - command_timeout=None, - ssl=None, - connection_class=Connection, - record_class=protocol.Record, - server_settings=None): +async def connect( + dsn=None, + *, + host=None, + port=None, + user=None, + password=None, + passfile=None, + database=None, + loop=None, + timeout=60, + statement_cache_size=100, + max_cached_statement_lifetime=300, + max_cacheable_statement_size=1024 * 15, + command_timeout=None, + ssl=None, + connection_class=Connection, + record_class=protocol.Record, + server_settings=None, +): r"""A coroutine to establish a connection to a PostgreSQL server. 
The connection parameters may be specified either as a connection @@ -2070,8 +2170,9 @@ async def connect(dsn=None, *, """ if not issubclass(connection_class, Connection): raise exceptions.InterfaceError( - 'connection_class is expected to be a subclass of ' - 'asyncpg.Connection, got {!r}'.format(connection_class)) + "connection_class is expected to be a subclass of " + "asyncpg.Connection, got {!r}".format(connection_class) + ) if record_class is not protocol.Record: _check_record_class(record_class) @@ -2102,7 +2203,7 @@ async def connect(dsn=None, *, class _StatementCacheEntry: - __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb') + __slots__ = ("_query", "_statement", "_cache", "_cleanup_cb") def __init__(self, cache, query, statement): self._cache = cache @@ -2113,8 +2214,7 @@ def __init__(self, cache, query, statement): class _StatementCache: - __slots__ = ('_loop', '_entries', '_max_size', '_on_remove', - '_max_lifetime') + __slots__ = ("_loop", "_entries", "_max_size", "_on_remove", "_max_lifetime") def __init__(self, *, loop, max_size, on_remove, max_lifetime): self._loop = loop @@ -2223,7 +2323,8 @@ def _set_entry_timeout(self, entry): # Set the new timeout if it's not 0. if self._max_lifetime: entry._cleanup_cb = self._loop.call_later( - self._max_lifetime, self._on_entry_expired, entry) + self._max_lifetime, self._on_entry_expired, entry + ) def _new_entry(self, query, statement): entry = _StatementCacheEntry(self, query, statement) @@ -2258,22 +2359,21 @@ class _Callback(typing.NamedTuple): is_async: bool @classmethod - def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + def from_callable(cls, cb: typing.Callable[..., None]) -> "_Callback": if inspect.iscoroutinefunction(cb): is_async = True elif callable(cb): is_async = False else: raise exceptions.InterfaceError( - 'expected a callable or an `async def` function,' - 'got {!r}'.format(cb) + "expected a callable or an `async def` function," "got {!r}".format(cb) ) return cls(cb, is_async) class _Atomic: - __slots__ = ('_acquired',) + __slots__ = ("_acquired",) def __init__(self): self._acquired = 0 @@ -2281,7 +2381,8 @@ def __init__(self): def __enter__(self): if self._acquired: raise exceptions.InterfaceError( - 'cannot perform operation: another operation is in progress') + "cannot perform operation: another operation is in progress" + ) self._acquired = 1 def __exit__(self, t, e, tb): @@ -2294,28 +2395,28 @@ class _ConnectionProxy: ServerCapabilities = collections.namedtuple( - 'ServerCapabilities', - ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', - 'sql_close_all']) -ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' + "ServerCapabilities", + ["advisory_locks", "notifications", "plpgsql", "sql_reset", "sql_close_all"], +) +ServerCapabilities.__doc__ = "PostgreSQL server capabilities." def _detect_server_capabilities(server_version, connection_settings): - if hasattr(connection_settings, 'padb_revision'): + if hasattr(connection_settings, "padb_revision"): # Amazon Redshift detected. advisory_locks = False notifications = False plpgsql = False sql_reset = True sql_close_all = False - elif hasattr(connection_settings, 'crdb_version'): + elif hasattr(connection_settings, "crdb_version"): # CockroachDB detected. advisory_locks = False notifications = False plpgsql = False sql_reset = False sql_close_all = False - elif hasattr(connection_settings, 'crate_version'): + elif hasattr(connection_settings, "crate_version"): # CrateDB detected. 
advisory_locks = False notifications = False @@ -2335,7 +2436,7 @@ def _detect_server_capabilities(server_version, connection_settings): notifications=notifications, plpgsql=plpgsql, sql_reset=sql_reset, - sql_close_all=sql_close_all + sql_close_all=sql_close_all, ) @@ -2346,7 +2447,8 @@ def _extract_stack(limit=10): frame = sys._getframe().f_back try: stack = traceback.StackSummary.extract( - traceback.walk_stack(frame), lookup_lines=False) + traceback.walk_stack(frame), lookup_lines=False + ) finally: del frame @@ -2354,30 +2456,27 @@ def _extract_stack(limit=10): i = 0 while i < len(stack) and stack[i][0].startswith(apg_path): i += 1 - stack = stack[i:i + limit] + stack = stack[i : i + limit] stack.reverse() - return ''.join(traceback.format_list(stack)) + return "".join(traceback.format_list(stack)) def _check_record_class(record_class): if record_class is protocol.Record: pass - elif ( - isinstance(record_class, type) - and issubclass(record_class, protocol.Record) - ): + elif isinstance(record_class, type) and issubclass(record_class, protocol.Record): if ( record_class.__new__ is not object.__new__ or record_class.__init__ is not object.__init__ ): raise exceptions.InterfaceError( - 'record_class must not redefine __new__ or __init__' + "record_class must not redefine __new__ or __init__" ) else: raise exceptions.InterfaceError( - 'record_class is expected to be a subclass of ' - 'asyncpg.Record, got {!r}'.format(record_class) + "record_class is expected to be a subclass of " + "asyncpg.Record, got {!r}".format(record_class) ) diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 8e241d67..48558c14 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -15,7 +15,7 @@ class PreparedStatement(connresource.ConnectionResource): """A representation of a prepared statement.""" - __slots__ = ('_state', '_query', '_last_status') + __slots__ = ("_state", "_query", "_last_status") def __init__(self, connection, query, state): super().__init__(connection) @@ -100,8 +100,7 @@ def get_attributes(self): return self._state._get_attributes() @connresource.guarded - def cursor(self, *args, prefetch=None, - timeout=None) -> cursor.CursorFactory: + def cursor(self, *args, prefetch=None, timeout=None) -> cursor.CursorFactory: """Return a *cursor factory* for the prepared statement. :param args: Query arguments. @@ -133,11 +132,11 @@ async def explain(self, *args, analyze=False): is actually a deserialized JSON output of the SQL ``EXPLAIN`` command. """ - query = 'EXPLAIN (FORMAT JSON, VERBOSE' + query = "EXPLAIN (FORMAT JSON, VERBOSE" if analyze: - query += ', ANALYZE) ' + query += ", ANALYZE) " else: - query += ') ' + query += ") " query += self._state.query if analyze: @@ -211,7 +210,7 @@ async def fetchrow(self, *args, timeout=None): return data[0] @connresource.guarded - async def executemany(self, args, *, timeout: float=None): + async def executemany(self, args, *, timeout: float = None): """Execute the statement for each sequence of arguments in *args*. :param args: An iterable containing sequences of arguments. @@ -221,11 +220,12 @@ async def executemany(self, args, *, timeout: float=None): .. 
versionadded:: 0.22.0 """ return await self.__do_execute( - lambda protocol: protocol.bind_execute_many( - self._state, args, '', timeout)) + lambda protocol: protocol.bind_execute_many(self._state, args, "", timeout) + ) async def __do_execute(self, executor): protocol = self._connection._protocol + self._connection._recent_statements.append(self._query) try: return await executor(protocol) except exceptions.OutdatedSchemaCacheError: @@ -240,15 +240,18 @@ async def __do_execute(self, executor): async def __bind_execute(self, args, limit, timeout): data, status, _ = await self.__do_execute( lambda protocol: protocol.bind_execute( - self._state, args, '', limit, True, timeout)) + self._state, args, "", limit, True, timeout + ) + ) self._last_status = status return data def _check_open(self, meth_name): if self._state.closed: raise exceptions.InterfaceError( - 'cannot call PreparedStmt.{}(): ' - 'the prepared statement is closed'.format(meth_name)) + "cannot call PreparedStmt.{}(): " + "the prepared statement is closed".format(meth_name) + ) def _check_conn_validity(self, meth_name): self._check_open(meth_name) diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index 2d7ba49f..f26240e1 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -6,10 +6,14 @@ import enum +import structlog + from . import connresource from . import exceptions as apg_errors +logger = structlog.get_logger(__name__) + class TransactionState(enum.Enum): NEW = 0 @@ -19,11 +23,11 @@ class TransactionState(enum.Enum): FAILED = 4 -ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'} +ISOLATION_LEVELS = {"read_committed", "serializable", "repeatable_read"} ISOLATION_LEVELS_BY_VALUE = { - 'read committed': 'read_committed', - 'serializable': 'serializable', - 'repeatable read': 'repeatable_read', + "read committed": "read_committed", + "serializable": "serializable", + "repeatable read": "repeatable_read", } @@ -35,16 +39,25 @@ class Transaction(connresource.ConnectionResource): function. 
""" - __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable', - '_state', '_nested', '_id', '_managed') + __slots__ = ( + "_connection", + "_isolation", + "_readonly", + "_deferrable", + "_state", + "_nested", + "_id", + "_managed", + ) def __init__(self, connection, isolation, readonly, deferrable): super().__init__(connection) if isolation and isolation not in ISOLATION_LEVELS: raise ValueError( - 'isolation is expected to be either of {}, ' - 'got {!r}'.format(ISOLATION_LEVELS, isolation)) + "isolation is expected to be either of {}, " + "got {!r}".format(ISOLATION_LEVELS, isolation) + ) self._isolation = isolation self._readonly = readonly @@ -57,13 +70,14 @@ def __init__(self, connection, isolation, readonly, deferrable): async def __aenter__(self): if self._managed: raise apg_errors.InterfaceError( - 'cannot enter context: already in an `async with` block') + "cannot enter context: already in an `async with` block" + ) self._managed = True await self.start() async def __aexit__(self, extype, ex, tb): try: - self._check_conn_validity('__aexit__') + self._check_conn_validity("__aexit__") except apg_errors.InterfaceError: if extype is GeneratorExit: # When a PoolAcquireContext is being exited, and there @@ -89,18 +103,25 @@ async def __aexit__(self, extype, ex, tb): @connresource.guarded async def start(self): """Enter the transaction or savepoint block.""" - self.__check_state_base('start') + self.__check_state_base("start") if self._state is TransactionState.STARTED: raise apg_errors.InterfaceError( - 'cannot start; the transaction is already started') + "cannot start; the transaction is already started" + ) con = self._connection if con._top_xact is None: if con._protocol.is_in_transaction(): + logger.error( + "bad transaction state for connection", + connection_id=id(con), + recent_statements=list(con._recent_statements), + ) raise apg_errors.InterfaceError( - 'cannot use Connection.transaction() in ' - 'a manually started transaction') + "cannot use Connection.transaction() in " + "a manually started transaction" + ) con._top_xact = self else: # Nested transaction block @@ -108,31 +129,33 @@ async def start(self): top_xact_isolation = con._top_xact._isolation if top_xact_isolation is None: top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[ - await self._connection.fetchval( - 'SHOW transaction_isolation;')] + await self._connection.fetchval("SHOW transaction_isolation;") + ] if self._isolation != top_xact_isolation: raise apg_errors.InterfaceError( - 'nested transaction has a different isolation level: ' - 'current {!r} != outer {!r}'.format( - self._isolation, top_xact_isolation)) + "nested transaction has a different isolation level: " + "current {!r} != outer {!r}".format( + self._isolation, top_xact_isolation + ) + ) self._nested = True if self._nested: - self._id = con._get_unique_id('savepoint') - query = 'SAVEPOINT {};'.format(self._id) + self._id = con._get_unique_id("savepoint") + query = "SAVEPOINT {};".format(self._id) else: - query = 'BEGIN' - if self._isolation == 'read_committed': - query += ' ISOLATION LEVEL READ COMMITTED' - elif self._isolation == 'repeatable_read': - query += ' ISOLATION LEVEL REPEATABLE READ' - elif self._isolation == 'serializable': - query += ' ISOLATION LEVEL SERIALIZABLE' + query = "BEGIN" + if self._isolation == "read_committed": + query += " ISOLATION LEVEL READ COMMITTED" + elif self._isolation == "repeatable_read": + query += " ISOLATION LEVEL REPEATABLE READ" + elif self._isolation == "serializable": + query += " ISOLATION 
LEVEL SERIALIZABLE" if self._readonly: - query += ' READ ONLY' + query += " READ ONLY" if self._deferrable: - query += ' DEFERRABLE' - query += ';' + query += " DEFERRABLE" + query += ";" try: await self._connection.execute(query) @@ -145,35 +168,35 @@ async def start(self): def __check_state_base(self, opname): if self._state is TransactionState.COMMITTED: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is already committed'.format( - opname)) + "cannot {}; the transaction is already committed".format(opname) + ) if self._state is TransactionState.ROLLEDBACK: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is already rolled back'.format( - opname)) + "cannot {}; the transaction is already rolled back".format(opname) + ) if self._state is TransactionState.FAILED: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is in error state'.format( - opname)) + "cannot {}; the transaction is in error state".format(opname) + ) def __check_state(self, opname): if self._state is not TransactionState.STARTED: if self._state is TransactionState.NEW: raise apg_errors.InterfaceError( - 'cannot {}; the transaction is not yet started'.format( - opname)) + "cannot {}; the transaction is not yet started".format(opname) + ) self.__check_state_base(opname) async def __commit(self): - self.__check_state('commit') + self.__check_state("commit") if self._connection._top_xact is self: self._connection._top_xact = None if self._nested: - query = 'RELEASE SAVEPOINT {};'.format(self._id) + query = "RELEASE SAVEPOINT {};".format(self._id) else: - query = 'COMMIT;' + query = "COMMIT;" try: await self._connection.execute(query) @@ -184,15 +207,15 @@ async def __commit(self): self._state = TransactionState.COMMITTED async def __rollback(self): - self.__check_state('rollback') + self.__check_state("rollback") if self._connection._top_xact is self: self._connection._top_xact = None if self._nested: - query = 'ROLLBACK TO {};'.format(self._id) + query = "ROLLBACK TO {};".format(self._id) else: - query = 'ROLLBACK;' + query = "ROLLBACK;" try: await self._connection.execute(query) @@ -207,7 +230,8 @@ async def commit(self): """Exit the transaction or savepoint block and commit changes.""" if self._managed: raise apg_errors.InterfaceError( - 'cannot manually commit from within an `async with` block') + "cannot manually commit from within an `async with` block" + ) await self.__commit() @connresource.guarded @@ -215,24 +239,26 @@ async def rollback(self): """Exit the transaction or savepoint block and rollback changes.""" if self._managed: raise apg_errors.InterfaceError( - 'cannot manually rollback from within an `async with` block') + "cannot manually rollback from within an `async with` block" + ) await self.__rollback() def __repr__(self): attrs = [] - attrs.append('state:{}'.format(self._state.name.lower())) + attrs.append("state:{}".format(self._state.name.lower())) if self._isolation is not None: attrs.append(self._isolation) if self._readonly: - attrs.append('readonly') + attrs.append("readonly") if self._deferrable: - attrs.append('deferrable') + attrs.append("deferrable") - if self.__class__.__module__.startswith('asyncpg.'): - mod = 'asyncpg' + if self.__class__.__module__.startswith("asyncpg."): + mod = "asyncpg" else: mod = self.__class__.__module__ - return '<{}.{} {} {:#x}>'.format( - mod, self.__class__.__name__, ' '.join(attrs), id(self)) + return "<{}.{} {} {:#x}>".format( + mod, self.__class__.__name__, " ".join(attrs), id(self) + ) From 
ec4367f87f537bd8692ff4a165772f521076ae0e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 1 Aug 2022 15:16:20 -0400 Subject: [PATCH 3/4] Log timestamp, context, arguments. Skip double-logging same statement. --- asyncpg/connection.py | 29 ++++++++++++++++++++++------- asyncpg/prepared_stmt.py | 8 ++++++-- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 566bec30..8207b59d 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -9,6 +9,7 @@ import asyncpg import collections import collections.abc +from datetime import datetime, timezone import functools import itertools import inspect @@ -305,6 +306,24 @@ def is_in_transaction(self): """ return self._protocol.is_in_transaction() + def log_statement(self, context: str, statement: str, args=None): + when = datetime.now(timezone.utc) + + if self._recent_statements: + most_recently_logged = self._recent_statements[-1] + ( + prior_when, + prior_context, + prior_statement, + prior_args, + ) = most_recently_logged + if prior_statement == statement: + # Do not double-log + return + + to_log = (when.strftime("%Y-%m-%d %H:%M:%S UTC"), context, statement, args) + self._recent_statements.append(to_log) + async def execute(self, query: str, *args, timeout: float = None) -> str: """Execute an SQL command (or commands). @@ -335,12 +354,8 @@ async def execute(self, query: str, *args, timeout: float = None) -> str: """ self._check_open() - # Append to circular buffer of most recent executed statements - # for debugging. - self._recent_statements.append(query) - if not args: - self._recent_statements.append(query) + self.log_statement("execute no args", query) return await self._protocol.query(query, timeout) _, status, _ = await self._execute( @@ -541,7 +556,7 @@ def cursor(self, query, *args, prefetch=None, timeout=None, record_class=None): """ self._check_open() - self._recent_statements.append(query) + self.log_statement("cursor", query, args) return cursor.CursorFactory( self, @@ -1802,7 +1817,7 @@ async def _do_execute( ignore_custom_codec=False, record_class=None, ): - self._recent_statements.append(query) + self.log_statement("_do_execute", query) if timeout is None: stmt = await self._get_statement( query, diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 48558c14..cd46421d 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -15,7 +15,7 @@ class PreparedStatement(connresource.ConnectionResource): """A representation of a prepared statement.""" - __slots__ = ("_state", "_query", "_last_status") + __slots__ = ("_state", "_query", "_last_status", "_logged_args") def __init__(self, connection, query, state): super().__init__(connection) @@ -219,13 +219,16 @@ async def executemany(self, args, *, timeout: float = None): .. 
versionadded:: 0.22.0 """ + self._logged_args = args return await self.__do_execute( lambda protocol: protocol.bind_execute_many(self._state, args, "", timeout) ) async def __do_execute(self, executor): protocol = self._connection._protocol - self._connection._recent_statements.append(self._query) + self._connection.log_statement( + "preparedstatement __do_execute", self._query, self._logged_args + ) try: return await executor(protocol) except exceptions.OutdatedSchemaCacheError: @@ -238,6 +241,7 @@ async def __do_execute(self, executor): raise async def __bind_execute(self, args, limit, timeout): + self._logged_args = args data, status, _ = await self.__do_execute( lambda protocol: protocol.bind_execute( self._state, args, "", limit, True, timeout From 1bfaa0adeafceb811b6eb099c4ac4c36fc78a8c3 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 1 Aug 2022 15:19:07 -0400 Subject: [PATCH 4/4] Comment. --- asyncpg/connection.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 8207b59d..2b8a9a21 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -311,14 +311,15 @@ def log_statement(self, context: str, statement: str, args=None): if self._recent_statements: most_recently_logged = self._recent_statements[-1] + # Each entry is (timestamp, calling context, sql statement, args (if any)) tuple. ( - prior_when, - prior_context, + _, + _, prior_statement, - prior_args, + _, ) = most_recently_logged if prior_statement == statement: - # Do not double-log + # Do not double-log same query. return to_log = (when.strftime("%Y-%m-%d %H:%M:%S UTC"), context, statement, args)
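
The copy_from_table() and copy_to_table() docstrings earlier in this series are symmetric, so a round trip is the quickest way to see the generated COPY statements. A minimal sketch; `mytable`, its columns, and the file name are illustrative, and `conn` is assumed to be an open asyncpg Connection.

import asyncpg

async def csv_round_trip(conn: asyncpg.Connection) -> None:
    # Runs: COPY mytable (id, name) TO STDOUT (FORMAT 'csv', HEADER true)
    await conn.copy_from_table(
        "mytable",
        columns=["id", "name"],
        output="mytable.csv",
        format="csv",
        header=True,
    )
    # Runs: COPY mytable (id, name) FROM STDIN (FORMAT 'csv', HEADER true)
    await conn.copy_to_table(
        "mytable",
        columns=["id", "name"],
        source="mytable.csv",
        format="csv",
        header=True,
    )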
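
copy_records_to_table() drives the same COPY machinery from in-memory tuples using the binary format, with the introspection SELECT shown above supplying the column types. A sketch that assumes a table mytable(id integer, name text) already exists:

async def load_records(conn: asyncpg.Connection) -> None:
    records = [(1, "alice"), (2, "bob")]
    # Each tuple must match the column list in order and type.
    status = await conn.copy_records_to_table(
        "mytable",
        records=records,
        columns=["id", "name"],
    )
    print(status)  # command status string, e.g. 'COPY 2'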
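
For set_type_codec(), the easiest concrete case is the JSON codec from asyncpg's own documentation: the encoder and decoder operate on the text wire representation, and the statement cache drop shown above is what makes the new codec take effect for subsequent queries.

import json

async def enable_json_codec(conn: asyncpg.Connection) -> None:
    await conn.set_type_codec(
        "json",
        schema="pg_catalog",
        encoder=json.dumps,
        decoder=json.loads,
    )
    # Values of type json now decode straight to Python objects.
    val = await conn.fetchval("SELECT '[1, 2, 3]'::json")
    assert val == [1, 2, 3]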
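
Transaction.start() issues BEGIN only for the outermost block and a SAVEPOINT for nested blocks, so the recent-statement buffer can legitimately contain SAVEPOINT, RELEASE SAVEPOINT, and ROLLBACK TO entries. A sketch of that nesting; `conn` is assumed open, table `t` is illustrative, and the savepoint names shown in comments follow the _get_unique_id() pattern rather than being exact.

async def nested_blocks(conn: asyncpg.Connection) -> None:
    async with conn.transaction():          # BEGIN;
        await conn.execute("INSERT INTO t VALUES (1)")
        try:
            async with conn.transaction():  # SAVEPOINT __asyncpg_savepoint_N__;
                await conn.execute("INSERT INTO t VALUES (2)")
                raise RuntimeError("abort the inner block only")
        except RuntimeError:
            # Inner __aexit__ issued ROLLBACK TO __asyncpg_savepoint_N__;
            pass
    # COMMIT; row 1 is kept, row 2 is not.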
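
The logger.error() call added to Transaction.start() is only as visible as the application's structlog configuration makes it; asyncpg itself configures nothing. A minimal setup that renders the recent-statement buffer as JSON, where the processor list is one reasonable choice rather than a requirement, and the keyword values are illustrative stand-ins for id(conn) and the live buffer:

import structlog

structlog.configure(
    processors=[
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer(),
    ]
)

# The same shape of event that Transaction.start() emits on the error path.
logger = structlog.get_logger("asyncpg.transaction")
logger.error(
    "bad transaction state for connection",
    connection_id=139872,
    recent_statements=["BEGIN", "SELECT 1;"],
)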
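
The _recent_statements attribute introduced in patch 1 is a collections.deque with maxlen=20, so the buffer is fixed-size and self-evicting: appends are O(1), never raise, and silently drop the oldest entry once the buffer is full.

from collections import deque

recent = deque(maxlen=20)
for i in range(25):
    recent.append("SELECT {};".format(i))

assert len(recent) == 20
assert recent[0] == "SELECT 5;"    # the five oldest entries were evicted
assert recent[-1] == "SELECT 24;"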
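
The de-duplication added in patch 3 only suppresses consecutive repeats of the same statement text; an alternating pair of statements is still logged every time. A standalone restatement of the bookkeeping, outside the Connection class so it can be unit-tested in isolation; names mirror the patch.

from collections import deque
from datetime import datetime, timezone

recent_statements = deque(maxlen=20)

def log_statement(context: str, statement: str, args=None) -> None:
    when = datetime.now(timezone.utc)
    if recent_statements:
        # Each entry is a (timestamp, calling context, SQL statement, args) tuple.
        _, _, prior_statement, _ = recent_statements[-1]
        if prior_statement == statement:
            # A consecutive duplicate keeps its original timestamp and context.
            return
    recent_statements.append(
        (when.strftime("%Y-%m-%d %H:%M:%S UTC"), context, statement, args)
    )

log_statement("execute no args", "BEGIN")
log_statement("_do_execute", "BEGIN")      # suppressed: same statement text
log_statement("_do_execute", "SELECT 1;")
assert len(recent_statements) == 2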
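
Finally, the failure this series instruments, reduced to a minimal reproduction; the DSN is a placeholder. With the patches applied, the InterfaceError is preceded by a structlog event whose recent_statements list makes the stray BEGIN visible.

import asyncio
import asyncpg

async def main() -> None:
    conn = await asyncpg.connect("postgresql://user@localhost/postgres")
    try:
        # A transaction opened behind the library's back...
        await conn.execute("BEGIN")
        # ...followed by a managed block raises InterfaceError('cannot use
        # Connection.transaction() in a manually started transaction').
        async with conn.transaction():
            pass
    finally:
        await conn.close()

asyncio.run(main())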