From bf6033c90e6bfe6c23a0e75d3ab445c465fdc071 Mon Sep 17 00:00:00 2001 From: rony batista Date: Sun, 28 Aug 2022 00:18:18 +0100 Subject: [PATCH 01/23] Add target session attribute connection param --- asyncpg/_testbase/__init__.py | 89 +++++++++++++++++++++++++++++++++++ asyncpg/connect_utils.py | 85 +++++++++++++++++++++++++++++---- asyncpg/connection.py | 24 +++++++++- asyncpg/exceptions/_base.py | 6 ++- tests/test_connect.py | 67 +++++++++++++++++++++++++- tests/test_pool.py | 48 +------------------ 6 files changed, 261 insertions(+), 58 deletions(-) diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 9944b20f..3dd8a314 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -435,3 +435,92 @@ def tearDown(self): self.con = None finally: super().tearDown() + + +class HotStandbyTestCase(ClusterTestCase): + @classmethod + def setup_cluster(cls): + cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster( + cls.master_cluster, + server_settings={ + 'max_wal_senders': 10, + 'wal_level': 'hot_standby' + } + ) + + con = None + + try: + con = cls.loop.run_until_complete( + cls.master_cluster.connect( + database='postgres', user='postgres', loop=cls.loop)) + + cls.loop.run_until_complete( + con.execute(''' + CREATE ROLE replication WITH LOGIN REPLICATION + ''')) + + cls.master_cluster.trust_local_replication_by('replication') + + conn_spec = cls.master_cluster.get_connection_spec() + + cls.standby_cluster = cls.new_cluster( + pg_cluster.HotStandbyCluster, + cluster_kwargs={ + 'master': conn_spec, + 'replication_user': 'replication' + } + ) + cls.start_cluster( + cls.standby_cluster, + server_settings={ + 'hot_standby': True + } + ) + + finally: + if con is not None: + cls.loop.run_until_complete(con.close()) + + @classmethod + def get_cluster_connection_spec(cls, cluster, kwargs={}): + conn_spec = cluster.get_connection_spec() + if kwargs.get('dsn'): + conn_spec.pop('host') + conn_spec.update(kwargs) + if not os.environ.get('PGHOST') and not kwargs.get('dsn'): + if 'database' not in conn_spec: + conn_spec['database'] = 'postgres' + if 'user' not in conn_spec: + conn_spec['user'] = 'postgres' + return conn_spec + + @classmethod + def get_connection_spec(cls, kwargs={}): + primary_spec = cls.get_cluster_connection_spec( + cls.master_cluster, kwargs + ) + standby_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, kwargs + ) + return { + 'host': [primary_spec['host'], standby_spec['host']], + 'port': [primary_spec['port'], standby_spec['port']], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + **kwargs + } + + @classmethod + def connect_primary(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs) + return pg_connection.connect(**conn_spec, loop=cls.loop) + + @classmethod + def connect_standby(cls, **kwargs): + conn_spec = cls.get_cluster_connection_spec( + cls.standby_cluster, + kwargs + ) + return pg_connection.connect(**conn_spec, loop=cls.loop) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 90a61503..a51eb789 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -13,6 +13,7 @@ import os import pathlib import platform +import random import re import socket import ssl as ssl_module @@ -56,6 +57,7 @@ def parse(cls, sslmode): 'direct_tls', 'connect_timeout', 'server_settings', + 'target_session_attribute', ]) @@ -259,7 +261,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, - direct_tls, connect_timeout, server_settings): + direct_tls, connect_timeout, server_settings, + target_session_attribute): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -603,7 +606,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, - connect_timeout=connect_timeout, server_settings=server_settings) + connect_timeout=connect_timeout, server_settings=server_settings, + target_session_attribute=target_session_attribute) return addrs, params @@ -613,8 +617,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, - ssl, direct_tls, server_settings): - + ssl, direct_tls, server_settings, + target_session_attribute): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -642,7 +646,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, - connect_timeout=timeout, server_settings=server_settings) + connect_timeout=timeout, server_settings=server_settings, + target_session_attribute=target_session_attribute) config = _ClientConfiguration( command_timeout=command_timeout, @@ -875,18 +880,64 @@ async def __connect_addr( return con +class SessionAttribute(str, enum.Enum): + any = 'any' + primary = 'primary' + standby = 'standby' + prefer_standby = 'prefer-standby' + + +def _accept_in_hot_standby(should_be_in_hot_standby: bool): + """ + If the server didn't report "in_hot_standby" at startup, we must determine + the state by checking "SELECT pg_catalog.pg_is_in_recovery()". + """ + async def can_be_used(connection): + settings = connection.get_settings() + hot_standby_status = getattr(settings, 'in_hot_standby', None) + if hot_standby_status is not None: + is_in_hot_standby = hot_standby_status == 'on' + else: + is_in_hot_standby = await connection.fetchval( + "SELECT pg_catalog.pg_is_in_recovery()" + ) + + return is_in_hot_standby == should_be_in_hot_standby + + return can_be_used + + +async def _accept_any(_): + return True + + +target_attrs_check = { + SessionAttribute.any: _accept_any, + SessionAttribute.primary: _accept_in_hot_standby(False), + SessionAttribute.standby: _accept_in_hot_standby(True), + SessionAttribute.prefer_standby: _accept_in_hot_standby(True), +} + + +async def _can_use_connection(connection, attr: SessionAttribute): + can_use = target_attrs_check[attr] + return await can_use(connection) + + async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) + target_attr = params.target_session_attribute + candidates = [] + chosen_connection = None last_error = None - addr = None for addr in addrs: before = time.monotonic() try: - return await _connect_addr( + conn = await _connect_addr( addr=addr, loop=loop, timeout=timeout, @@ -895,12 +946,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): connection_class=connection_class, record_class=record_class, ) + candidates.append(conn) + if await _can_use_connection(conn, target_attr): + chosen_connection = conn + break except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex finally: timeout -= time.monotonic() - before + else: + if target_attr == SessionAttribute.prefer_standby and candidates: + chosen_connection = random.choice(candidates) + + await asyncio.gather( + (c.close() for c in candidates if c is not chosen_connection), + return_exceptions=True + ) + + if chosen_connection: + return chosen_connection - raise last_error + raise last_error or exceptions.TargetServerAttributeNotMatched( + 'None of the hosts match the target attribute requirement ' + '{!r}'.format(target_attr) + ) async def _cancel(*, loop, addr, params: _ConnectionParameters, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ea128aab..6797c54e 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -30,6 +30,7 @@ from . import serverversion from . import transaction from . import utils +from .connect_utils import SessionAttribute class ConnectionMeta(type): @@ -1792,7 +1793,8 @@ async def connect(dsn=None, *, direct_tls=False, connection_class=Connection, record_class=protocol.Record, - server_settings=None): + server_settings=None, + target_session_attribute=SessionAttribute.any): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2003,6 +2005,16 @@ async def connect(dsn=None, *, this connection object. Must be a subclass of :class:`~asyncpg.Record`. + :param SessionAttribute target_session_attribute: + If specified, check that the host has the correct attribute. + Can be one of: + "any": the first successfully connected host + "primary": the host must NOT be in hot standby mode + "standby": the host must be in hot standby mode + "prefer-standby": first try to find a standby host, but if + none of the listed hosts is a standby server, + return any of them. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2087,6 +2099,15 @@ async def connect(dsn=None, *, if record_class is not protocol.Record: _check_record_class(record_class) + try: + target_session_attribute = SessionAttribute(target_session_attribute) + except ValueError as exc: + raise exceptions.InterfaceError( + "target_session_attribute is expected to be one of " + "'any', 'primary', 'standby' or 'prefer-standby'" + ", got {!r}".format(target_session_attribute) + ) from exc + if loop is None: loop = asyncio.get_event_loop() @@ -2109,6 +2130,7 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, + target_session_attribute=target_session_attribute ) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 783b5eb5..de981d25 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -13,7 +13,7 @@ __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', - 'UnsupportedClientFeatureError') + 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched') def _is_asyncpg_class(cls): @@ -244,6 +244,10 @@ class ProtocolError(InternalClientError): """Unexpected condition in the handling of PostgreSQL protocol input.""" +class TargetServerAttributeNotMatched(InternalClientError): + """Could not find a host that satisfies the target attribute requirement""" + + class OutdatedSchemaCacheError(InternalClientError): """A value decoding error caused by a schema change before row fetching.""" diff --git a/tests/test_connect.py b/tests/test_connect.py index db7817f6..f905e3cd 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -789,6 +789,7 @@ def run_testcase(self, testcase): database = testcase.get('database') sslmode = testcase.get('ssl') server_settings = testcase.get('server_settings') + target_session_attribute = testcase.get('target_session_attribute') expected = testcase.get('result') expected_error = testcase.get('error') @@ -812,7 +813,8 @@ def run_testcase(self, testcase): dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, database=database, ssl=sslmode, direct_tls=False, connect_timeout=None, - server_settings=server_settings) + server_settings=server_settings, + target_session_attribute=target_session_attribute) params = { k: v for k, v in params._asdict().items() @@ -1743,3 +1745,66 @@ async def test_no_explicit_close_with_debug(self): self.assertIn('in test_no_explicit_close_with_debug', msg) finally: self.loop.set_debug(olddebug) + + +class TestConnectionAttributes(tb.HotStandbyTestCase): + + async def _run_connection_test( + self, connect, target_attribute, expected_host + ): + conn = await connect(target_session_attribute=target_attribute) + self.assertTrue(_get_connected_host(conn).startswith(expected_host)) + await conn.close() + + async def test_target_server_attribute_host(self): + master_host = self.master_cluster.get_connection_spec()['host'] + standby_host = self.standby_cluster.get_connection_spec()['host'] + tests = [ + (self.connect_primary, 'primary', master_host), + (self.connect_standby, 'standby', standby_host), + ] + + for connect, target_attr, expected_host in tests: + await self._run_connection_test( + connect, target_attr, expected_host + ) + + async def test_target_attribute_not_matched(self): + tests = [ + (self.connect_standby, 'primary'), + (self.connect_primary, 'standby'), + ] + + for connect, target_attr in tests: + with self.assertRaises(exceptions.TargetServerAttributeNotMatched): + await connect(target_session_attribute=target_attr) + + async def test_prefer_standby_when_standby_is_up(self): + con = await self.connect(target_session_attribute='prefer-standby') + standby_host = self.standby_cluster.get_connection_spec()['host'] + connected_host = _get_connected_host(con) + self.assertTrue(connected_host.startswith(standby_host)) + await con.close() + + async def test_prefer_standby_picks_master_when_standby_is_down(self): + primary_spec = self.get_cluster_connection_spec(self.master_cluster) + connection_spec = { + 'host': [ + primary_spec['host'], + '/var/test/a/cluster/that/does/not/exist', + ], + 'port': [primary_spec['port'], 12345], + 'database': primary_spec['database'], + 'user': primary_spec['user'], + 'target_session_attribute': 'prefer-standby' + } + + con = await connection.connect(**connection_spec, loop=self.loop) + master_host = self.master_cluster.get_connection_spec()['host'] + connected_host = _get_connected_host(con) + self.assertTrue(connected_host.startswith(master_host)) + await con.close() + + +def _get_connected_host(con): + return con._transport.get_extra_info('peername') diff --git a/tests/test_pool.py b/tests/test_pool.py index e2c99efc..f96cd2a6 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -18,7 +18,6 @@ import asyncpg from asyncpg import _testbase as tb from asyncpg import connection as pg_connection -from asyncpg import cluster as pg_cluster from asyncpg import pool as pg_pool _system = platform.uname().system @@ -964,52 +963,7 @@ async def worker(): @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') -class TestHotStandby(tb.ClusterTestCase): - @classmethod - def setup_cluster(cls): - cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) - cls.start_cluster( - cls.master_cluster, - server_settings={ - 'max_wal_senders': 10, - 'wal_level': 'hot_standby' - } - ) - - con = None - - try: - con = cls.loop.run_until_complete( - cls.master_cluster.connect( - database='postgres', user='postgres', loop=cls.loop)) - - cls.loop.run_until_complete( - con.execute(''' - CREATE ROLE replication WITH LOGIN REPLICATION - ''')) - - cls.master_cluster.trust_local_replication_by('replication') - - conn_spec = cls.master_cluster.get_connection_spec() - - cls.standby_cluster = cls.new_cluster( - pg_cluster.HotStandbyCluster, - cluster_kwargs={ - 'master': conn_spec, - 'replication_user': 'replication' - } - ) - cls.start_cluster( - cls.standby_cluster, - server_settings={ - 'hot_standby': True - } - ) - - finally: - if con is not None: - cls.loop.run_until_complete(con.close()) - +class TestHotStandby(tb.HotStandbyTestCase): def create_pool(self, **kwargs): conn_spec = self.standby_cluster.get_connection_spec() conn_spec.update(kwargs) From bee17cb8415c76204ec21d25399d39faa5333fd7 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Wed, 7 Dec 2022 22:23:28 +0100 Subject: [PATCH 02/23] Fixed tests so they pass on windows. Are the hosts unique for the replica clusters ? I find it hard to tell, but at least on Windows all clusters are on localhost, so the test was not actually verifying the code. --- tests/test_connect.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index f905e3cd..e8145a80 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -25,7 +25,7 @@ import asyncpg from asyncpg import _testbase as tb -from asyncpg import connection +from asyncpg import connection as pg_connection from asyncpg import connect_utils from asyncpg import cluster as pg_cluster from asyncpg import exceptions @@ -1167,7 +1167,7 @@ async def test_connect_args_validation(self): class TestConnection(tb.ConnectedTestCase): async def test_connection_isinstance(self): - self.assertTrue(isinstance(self.con, connection.Connection)) + self.assertTrue(isinstance(self.con, pg_connection.Connection)) self.assertTrue(isinstance(self.con, object)) self.assertFalse(isinstance(self.con, list)) @@ -1750,23 +1750,23 @@ async def test_no_explicit_close_with_debug(self): class TestConnectionAttributes(tb.HotStandbyTestCase): async def _run_connection_test( - self, connect, target_attribute, expected_host + self, connect, target_attribute, expected_port ): conn = await connect(target_session_attribute=target_attribute) - self.assertTrue(_get_connected_host(conn).startswith(expected_host)) + self.assertTrue(_get_connected_host(conn).endswith(expected_port)) await conn.close() - async def test_target_server_attribute_host(self): - master_host = self.master_cluster.get_connection_spec()['host'] - standby_host = self.standby_cluster.get_connection_spec()['host'] + async def test_target_server_attribute_port(self): + master_port = self.master_cluster.get_connection_spec()['port'] + standby_port = self.standby_cluster.get_connection_spec()['port'] tests = [ - (self.connect_primary, 'primary', master_host), - (self.connect_standby, 'standby', standby_host), + (self.connect_primary, 'primary', master_port), + (self.connect_standby, 'standby', standby_port), ] - for connect, target_attr, expected_host in tests: + for connect, target_attr, expected_port in tests: await self._run_connection_test( - connect, target_attr, expected_host + connect, target_attr, expected_port ) async def test_target_attribute_not_matched(self): @@ -1781,9 +1781,9 @@ async def test_target_attribute_not_matched(self): async def test_prefer_standby_when_standby_is_up(self): con = await self.connect(target_session_attribute='prefer-standby') - standby_host = self.standby_cluster.get_connection_spec()['host'] + standby_port = self.standby_cluster.get_connection_spec()['port'] connected_host = _get_connected_host(con) - self.assertTrue(connected_host.startswith(standby_host)) + self.assertTrue(connected_host.endswith(standby_port)) await con.close() async def test_prefer_standby_picks_master_when_standby_is_down(self): @@ -1791,20 +1791,23 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): connection_spec = { 'host': [ primary_spec['host'], - '/var/test/a/cluster/that/does/not/exist', + 'unlocalhost', ], - 'port': [primary_spec['port'], 12345], + 'port': [primary_spec['port'], 15345], 'database': primary_spec['database'], 'user': primary_spec['user'], 'target_session_attribute': 'prefer-standby' } - con = await connection.connect(**connection_spec, loop=self.loop) - master_host = self.master_cluster.get_connection_spec()['host'] + con = await self.connect(**connection_spec) + master_port = self.master_cluster.get_connection_spec()['port'] connected_host = _get_connected_host(con) - self.assertTrue(connected_host.startswith(master_host)) + self.assertTrue(connected_host.endswith(master_port)) await con.close() def _get_connected_host(con): - return con._transport.get_extra_info('peername') + peername = con._transport.get_extra_info('peername') + if isinstance(peername, tuple): + peername = "".join((str(s) for s in peername if s)) + return peername From 1c675119fca1c55ce0b331a4133c461510f745a9 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 8 Dec 2022 08:10:04 +0100 Subject: [PATCH 03/23] push for workflows --- tests/test_connect.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_connect.py b/tests/test_connect.py index e8145a80..c766f00e 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1806,6 +1806,7 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): await con.close() + def _get_connected_host(con): peername = con._transport.get_extra_info('peername') if isinstance(peername, tuple): From fae9d9cf930c2dcfd75c749a78719f81455ba0f2 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 8 Dec 2022 08:43:43 +0100 Subject: [PATCH 04/23] push for workflows --- tests/test_connect.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index c766f00e..e8145a80 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1806,7 +1806,6 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): await con.close() - def _get_connected_host(con): peername = con._transport.get_extra_info('peername') if isinstance(peername, tuple): From 7a847ce9c74e7117e9c6b5ec926a524da888f2e3 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 8 Dec 2022 08:53:09 +0100 Subject: [PATCH 05/23] no image for python 3.6 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d61573db..28e38e4e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: # job. strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10"] os: [ubuntu-latest, macos-latest, windows-latest] loop: [asyncio, uvloop] exclude: From a3d7342e29d9a9e2813ecd81de97f386c817f2f1 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 8 Dec 2022 09:21:17 +0100 Subject: [PATCH 06/23] merge ci changes from master --- .github/workflows/tests.yml | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 28e38e4e..a120e9a6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,13 +17,10 @@ jobs: # job. strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] os: [ubuntu-latest, macos-latest, windows-latest] loop: [asyncio, uvloop] exclude: - # uvloop does not support Python 3.6 - - loop: uvloop - python-version: "3.6" # uvloop does not support windows - loop: uvloop os: windows-latest @@ -38,7 +35,7 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 50 submodules: true @@ -54,7 +51,7 @@ jobs: __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 if: steps.release.outputs.version == 0 with: python-version: ${{ matrix.python-version }} @@ -79,7 +76,7 @@ jobs: test-postgres: strategy: matrix: - postgres-version: ["9.5", "9.6", "10", "11", "12", "13", "14"] + postgres-version: ["9.5", "9.6", "10", "11", "12", "13", "14", "15"] runs-on: ubuntu-latest @@ -87,7 +84,7 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 50 submodules: true @@ -114,8 +111,10 @@ jobs: >> "${GITHUB_ENV}" - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 if: steps.release.outputs.version == 0 + with: + python-version: "3.x" - name: Install Python Deps if: steps.release.outputs.version == 0 From 7d9234f87aff97728f05237b79af25328dd80b70 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Sun, 11 Dec 2022 11:06:26 +0100 Subject: [PATCH 07/23] merge setup.py changes from master --- setup.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 332bad3f..af0bcdc3 100644 --- a/setup.py +++ b/setup.py @@ -7,8 +7,8 @@ import sys -if sys.version_info < (3, 6): - raise RuntimeError('asyncpg requires Python 3.6 or greater') +if sys.version_info < (3, 7): + raise RuntimeError('asyncpg requires Python 3.7 or greater') import os import os.path @@ -29,12 +29,8 @@ # Minimal dependencies required to test asyncpg. TEST_DEPENDENCIES = [ - # pycodestyle is a dependency of flake8, but it must be frozen because - # their combination breaks too often - # (example breakage: https://gitlab.com/pycqa/flake8/issues/427) - 'pycodestyle~=2.7.0', - 'flake8~=3.9.2', - 'uvloop>=0.15.3; platform_system != "Windows" and python_version >= "3.7"', + 'flake8~=5.0.4', + 'uvloop>=0.15.3; platform_system != "Windows"', ] # Dependencies required to build documentation. @@ -259,7 +255,6 @@ def finalize_options(self): 'Operating System :: MacOS :: MacOS X', 'Operating System :: Microsoft :: Windows', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', @@ -268,7 +263,7 @@ def finalize_options(self): 'Topic :: Database :: Front-Ends', ], platforms=['macOS', 'POSIX', 'Windows'], - python_requires='>=3.6.0', + python_requires='>=3.7.0', zip_safe=False, author='MagicStack Inc', author_email='hello@magic.io', From a43b0646fdf08777bb92b327d3393ab96dcb0c74 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 12 Dec 2022 08:33:26 +0100 Subject: [PATCH 08/23] Add logging to server selection procedure --- asyncpg/connect_utils.py | 12 +++++++++--- tests/test_connect.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index a51eb789..97dd8065 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -10,6 +10,7 @@ import enum import functools import getpass +import logging import os import pathlib import platform @@ -30,6 +31,7 @@ from . import exceptions from . import protocol +logger = logging.getLogger(__name__) class SSLMode(enum.IntEnum): disable = 0 @@ -898,11 +900,15 @@ async def can_be_used(connection): if hot_standby_status is not None: is_in_hot_standby = hot_standby_status == 'on' else: - is_in_hot_standby = await connection.fetchval( + is_in_recovery = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" ) - - return is_in_hot_standby == should_be_in_hot_standby + if is_in_recovery: + logger.warning("Connection {!r} is still in recovery mode".format(connection)) + is_in_hot_standby = not is_in_recovery + connection_eligible = is_in_hot_standby == should_be_in_hot_standby + logger.debug("Connection {!r} is eligible ({!r}). Allow".format(connection, connection_eligible)) + return connection_eligible return can_be_used diff --git a/tests/test_connect.py b/tests/test_connect.py index e8145a80..161b6f0c 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1740,7 +1740,7 @@ async def test_no_explicit_close_with_debug(self): r'unclosed connection') as rw: await self._run_no_explicit_close_test() - msg = rw.warning.args[0] + msg = " ".join(rw.warning.args) self.assertIn(' created at:\n', msg) self.assertIn('in test_no_explicit_close_with_debug', msg) finally: From 24ea4e58e84a8a27092a15b26ee4e561f3d26dac Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 12 Dec 2022 08:38:37 +0100 Subject: [PATCH 09/23] formatting --- asyncpg/connect_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 97dd8065..e37ce310 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -33,6 +33,7 @@ logger = logging.getLogger(__name__) + class SSLMode(enum.IntEnum): disable = 0 allow = 1 @@ -904,10 +905,14 @@ async def can_be_used(connection): "SELECT pg_catalog.pg_is_in_recovery()" ) if is_in_recovery: - logger.warning("Connection {!r} is still in recovery mode".format(connection)) + logger.warning("Connection {!r} is still in recovery mode" + .format(connection)) is_in_hot_standby = not is_in_recovery connection_eligible = is_in_hot_standby == should_be_in_hot_standby - logger.debug("Connection {!r} is eligible ({!r}). Allow".format(connection, connection_eligible)) + logger.debug( + "Connection {!r} eligible=({!r}). Allow hot standby={!r}". + format(connection, connection_eligible, should_be_in_hot_standby) + ) return connection_eligible return can_be_used From 3df816dca375f6c3ade07b3bce7cdcd4b74edd80 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 12 Dec 2022 08:51:55 +0100 Subject: [PATCH 10/23] undo mistake --- asyncpg/connect_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index e37ce310..57edfac0 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -894,6 +894,8 @@ def _accept_in_hot_standby(should_be_in_hot_standby: bool): """ If the server didn't report "in_hot_standby" at startup, we must determine the state by checking "SELECT pg_catalog.pg_is_in_recovery()". + If the server allows a connection and states it is in recovery it must + be a replica/standby server. """ async def can_be_used(connection): settings = connection.get_settings() @@ -901,13 +903,9 @@ async def can_be_used(connection): if hot_standby_status is not None: is_in_hot_standby = hot_standby_status == 'on' else: - is_in_recovery = await connection.fetchval( + is_in_hot_standby = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" ) - if is_in_recovery: - logger.warning("Connection {!r} is still in recovery mode" - .format(connection)) - is_in_hot_standby = not is_in_recovery connection_eligible = is_in_hot_standby == should_be_in_hot_standby logger.debug( "Connection {!r} eligible=({!r}). Allow hot standby={!r}". From e24a091d7900e57c34f3757260af8d05f0e747aa Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 12 Dec 2022 08:58:47 +0100 Subject: [PATCH 11/23] fix test? --- tests/test_connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 161b6f0c..a18a22e4 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1464,7 +1464,7 @@ async def test_executemany_uvloop_ssl_issue_700(self): ) finally: try: - await con.execute('DROP TABLE test_many') + await con.execute('DROP TABLE IF EXISTS test_many') finally: await con.close() From 38984a68895697f013316311203bc9ca27b1dc4e Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 12 Dec 2022 09:27:59 +0100 Subject: [PATCH 12/23] disable tests for pg11 to see if all the rest of the test cases pass --- asyncpg/connect_utils.py | 7 ------- tests/test_connect.py | 6 ++++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 57edfac0..50b49d0a 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -10,7 +10,6 @@ import enum import functools import getpass -import logging import os import pathlib import platform @@ -31,8 +30,6 @@ from . import exceptions from . import protocol -logger = logging.getLogger(__name__) - class SSLMode(enum.IntEnum): disable = 0 @@ -907,10 +904,6 @@ async def can_be_used(connection): "SELECT pg_catalog.pg_is_in_recovery()" ) connection_eligible = is_in_hot_standby == should_be_in_hot_standby - logger.debug( - "Connection {!r} eligible=({!r}). Allow hot standby={!r}". - format(connection, connection_eligible, should_be_in_hot_standby) - ) return connection_eligible return can_be_used diff --git a/tests/test_connect.py b/tests/test_connect.py index a18a22e4..4935f776 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1757,6 +1757,8 @@ async def _run_connection_test( await conn.close() async def test_target_server_attribute_port(self): + if self.cluster.get_pg_version()[0] == 11: + self.skipTest("PostgreSQL 11 seems to have issues with this test") master_port = self.master_cluster.get_connection_spec()['port'] standby_port = self.standby_cluster.get_connection_spec()['port'] tests = [ @@ -1770,6 +1772,8 @@ async def test_target_server_attribute_port(self): ) async def test_target_attribute_not_matched(self): + if self.cluster.get_pg_version()[0] == 11: + self.skipTest("PostgreSQL 11 seems to have issues with this test") tests = [ (self.connect_standby, 'primary'), (self.connect_primary, 'standby'), @@ -1780,6 +1784,8 @@ async def test_target_attribute_not_matched(self): await connect(target_session_attribute=target_attr) async def test_prefer_standby_when_standby_is_up(self): + if self.cluster.get_pg_version()[0] == 11: + self.skipTest("PostgreSQL 11 seems to have issues with this test") con = await self.connect(target_session_attribute='prefer-standby') standby_port = self.standby_cluster.get_connection_spec()['port'] connected_host = _get_connected_host(con) From 86423c31e6118e68daf7965811cf44b3c64cf047 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 15 Dec 2022 08:21:39 +0100 Subject: [PATCH 13/23] disable tests for pg11 to see if all the rest of the test cases pass --- tests/test_connect.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 4935f776..18236f64 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1757,7 +1757,7 @@ async def _run_connection_test( await conn.close() async def test_target_server_attribute_port(self): - if self.cluster.get_pg_version()[0] == 11: + if self.master_cluster.get_pg_version()[0] == 11: self.skipTest("PostgreSQL 11 seems to have issues with this test") master_port = self.master_cluster.get_connection_spec()['port'] standby_port = self.standby_cluster.get_connection_spec()['port'] @@ -1772,7 +1772,7 @@ async def test_target_server_attribute_port(self): ) async def test_target_attribute_not_matched(self): - if self.cluster.get_pg_version()[0] == 11: + if self.master_cluster.get_pg_version()[0] == 11: self.skipTest("PostgreSQL 11 seems to have issues with this test") tests = [ (self.connect_standby, 'primary'), @@ -1784,7 +1784,7 @@ async def test_target_attribute_not_matched(self): await connect(target_session_attribute=target_attr) async def test_prefer_standby_when_standby_is_up(self): - if self.cluster.get_pg_version()[0] == 11: + if self.master_cluster.get_pg_version()[0] == 11: self.skipTest("PostgreSQL 11 seems to have issues with this test") con = await self.connect(target_session_attribute='prefer-standby') standby_port = self.standby_cluster.get_connection_spec()['port'] @@ -1793,6 +1793,8 @@ async def test_prefer_standby_when_standby_is_up(self): await con.close() async def test_prefer_standby_picks_master_when_standby_is_down(self): + if self.master_cluster.get_pg_version()[0] == 11: + self.skipTest("PostgreSQL 11 seems to have issues with this test") primary_spec = self.get_cluster_connection_spec(self.master_cluster) connection_spec = { 'host': [ From 277ed968a3c672a24fb192ba1500f49af2f258d0 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Fri, 16 Dec 2022 09:02:40 +0100 Subject: [PATCH 14/23] add some more fixes that were already implemented --- tests/test_connect.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_connect.py b/tests/test_connect.py index 18236f64..51e92391 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1290,6 +1290,7 @@ def setUp(self): create_script = [] create_script.append('CREATE ROLE ssl_user WITH LOGIN;') + create_script.append('GRANT ALL ON SCHEMA public TO ssl_user;') self._add_hba_entry() @@ -1304,6 +1305,7 @@ def tearDown(self): self.cluster.trust_local_connections() drop_script = [] + drop_script.append('REVOKE ALL ON SCHEMA public FROM ssl_user;') drop_script.append('DROP ROLE ssl_user;') drop_script = '\n'.join(drop_script) self.loop.run_until_complete(self.con.execute(drop_script)) From 5737f767f7da5ddb687f61328ba8d948e674cc6a Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Fri, 16 Dec 2022 21:11:05 +0100 Subject: [PATCH 15/23] Add support for read-write and read-only target_session_attribute options --- asyncpg/_testbase/__init__.py | 1 + asyncpg/connect_utils.py | 25 +++++++++++++++++++++++++ tests/test_connect.py | 22 ++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 3dd8a314..7aca834f 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -438,6 +438,7 @@ def tearDown(self): class HotStandbyTestCase(ClusterTestCase): + @classmethod def setup_cluster(cls): cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 50b49d0a..93fb0973 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -885,6 +885,8 @@ class SessionAttribute(str, enum.Enum): primary = 'primary' standby = 'standby' prefer_standby = 'prefer-standby' + read_write = "read-write" + read_only = "read-only" def _accept_in_hot_standby(should_be_in_hot_standby: bool): @@ -909,6 +911,27 @@ async def can_be_used(connection): return can_be_used +def _accept_read_only(should_be_read_only: bool): + """ + Verify the server has not set default_transaction_read_only=True + """ + async def can_be_used(connection): + settings = connection.get_settings() + is_read_only = getattr(settings, 'default_transaction_read_only', None) + if is_read_only is not None: + is_read_only = is_read_only == "on" + else: + is_read_only = False + if should_be_read_only: + if is_read_only: + return True + elif await _accept_in_hot_standby(True)(connection): + return True + return False + return _accept_in_hot_standby(False)(connection) + return can_be_used + + async def _accept_any(_): return True @@ -918,6 +941,8 @@ async def _accept_any(_): SessionAttribute.primary: _accept_in_hot_standby(False), SessionAttribute.standby: _accept_in_hot_standby(True), SessionAttribute.prefer_standby: _accept_in_hot_standby(True), + SessionAttribute.read_write: _accept_read_only(False), + SessionAttribute.read_only: _accept_read_only(True), } diff --git a/tests/test_connect.py b/tests/test_connect.py index 51e92391..5ccfd4a9 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1768,6 +1768,17 @@ async def test_target_server_attribute_port(self): (self.connect_standby, 'standby', standby_port), ] + for connect, target_attr, expected_port in tests: + await self._run_connection_test( + connect, target_attr, expected_port + ) + if self.master_cluster.get_pg_version()[0] < 14: + self.skipTest("PostgreSQL<14 does not support these features") + tests = [ + (self.connect_primary, 'read-write', master_port), + (self.connect_standby, 'read-only', standby_port), + ] + for connect, target_attr, expected_port in tests: await self._run_connection_test( connect, target_attr, expected_port @@ -1785,6 +1796,17 @@ async def test_target_attribute_not_matched(self): with self.assertRaises(exceptions.TargetServerAttributeNotMatched): await connect(target_session_attribute=target_attr) + if self.master_cluster.get_pg_version()[0] < 14: + self.skipTest("PostgreSQL<14 does not support these features") + tests = [ + (self.connect_standby, 'read-write'), + (self.connect_primary, 'read-only'), + ] + + for connect, target_attr in tests: + with self.assertRaises(exceptions.TargetServerAttributeNotMatched): + await connect(target_session_attribute=target_attr) + async def test_prefer_standby_when_standby_is_up(self): if self.master_cluster.get_pg_version()[0] == 11: self.skipTest("PostgreSQL 11 seems to have issues with this test") From c3133c04c9cc1b0cbe0805a42331f56eddc98f0c Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Fri, 16 Dec 2022 21:27:16 +0100 Subject: [PATCH 16/23] fix little logical error --- asyncpg/connect_utils.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 93fb0973..e5329982 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -917,18 +917,11 @@ def _accept_read_only(should_be_read_only: bool): """ async def can_be_used(connection): settings = connection.get_settings() - is_read_only = getattr(settings, 'default_transaction_read_only', None) - if is_read_only is not None: - is_read_only = is_read_only == "on" - else: - is_read_only = False - if should_be_read_only: - if is_read_only: - return True - elif await _accept_in_hot_standby(True)(connection): - return True - return False - return _accept_in_hot_standby(False)(connection) + is_read_only = getattr(settings, 'default_transaction_read_only', 'off') + + if should_be_read_only and is_read_only == "on": + return True + return await _accept_in_hot_standby(should_be_read_only)(connection) return can_be_used From ae325ba83c075ef11dfaa3cef86f427ab0ab17dc Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Fri, 16 Dec 2022 21:31:12 +0100 Subject: [PATCH 17/23] linter --- asyncpg/connect_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index e5329982..c966556b 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -917,9 +917,9 @@ def _accept_read_only(should_be_read_only: bool): """ async def can_be_used(connection): settings = connection.get_settings() - is_read_only = getattr(settings, 'default_transaction_read_only', 'off') + is_readonly = getattr(settings, 'default_transaction_read_only', 'off') - if should_be_read_only and is_read_only == "on": + if should_be_read_only and is_readonly == "on": return True return await _accept_in_hot_standby(should_be_read_only)(connection) return can_be_used From 3bc832245e5a49a0766dfa29469497688df4c2b0 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Sun, 18 Dec 2022 10:25:05 +0100 Subject: [PATCH 18/23] fix logic issue --- asyncpg/connect_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c966556b..5b78f27d 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -905,8 +905,7 @@ async def can_be_used(connection): is_in_hot_standby = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" ) - connection_eligible = is_in_hot_standby == should_be_in_hot_standby - return connection_eligible + return is_in_hot_standby == should_be_in_hot_standby return can_be_used @@ -919,8 +918,9 @@ async def can_be_used(connection): settings = connection.get_settings() is_readonly = getattr(settings, 'default_transaction_read_only', 'off') - if should_be_read_only and is_readonly == "on": - return True + if is_readonly == "on": + return should_be_read_only + return await _accept_in_hot_standby(should_be_read_only)(connection) return can_be_used From d824333212e807178d4a40abaadcac5cb1848de2 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Sun, 29 Jan 2023 10:54:07 +0100 Subject: [PATCH 19/23] Update based on review. --- asyncpg/connect_utils.py | 25 +++++++++++++++++++------ asyncpg/connection.py | 20 ++++++++------------ tests/test_connect.py | 14 +++++++------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index ccc19fa2..89c6c939 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -57,7 +57,7 @@ def parse(cls, sslmode): 'direct_tls', 'connect_timeout', 'server_settings', - 'target_session_attribute', + 'target_session_attrs', ]) @@ -258,7 +258,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, direct_tls, connect_timeout, server_settings, - target_session_attribute): + target_session_attrs): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -595,11 +595,24 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, 'server_settings is expected to be None or ' 'a Dict[str, str]') + if target_session_attrs is None: + + target_session_attrs = os.getenv("PGTARGETSESSIONATTRS", SessionAttribute.any) + try: + + target_session_attrs = SessionAttribute(target_session_attrs) + except ValueError as exc: + raise exceptions.InterfaceError( + "target_session_attrs is expected to be one of " + "{!r}" + ", got {!r}".format(SessionAttribute.__members__.values, target_session_attrs) + ) from exc + params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, connect_timeout=connect_timeout, server_settings=server_settings, - target_session_attribute=target_session_attribute) + target_session_attrs=target_session_attrs) return addrs, params @@ -610,7 +623,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attribute): + target_session_attrs): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -639,7 +652,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, connect_timeout=timeout, server_settings=server_settings, - target_session_attribute=target_session_attribute) + target_session_attrs=target_session_attrs) config = _ClientConfiguration( command_timeout=command_timeout, @@ -941,7 +954,7 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): loop = asyncio.get_event_loop() addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) - target_attr = params.target_session_attribute + target_attr = params.target_session_attrs candidates = [] chosen_connection = None diff --git a/asyncpg/connection.py b/asyncpg/connection.py index cec576f0..095ad398 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1794,7 +1794,7 @@ async def connect(dsn=None, *, connection_class=Connection, record_class=protocol.Record, server_settings=None, - target_session_attribute=SessionAttribute.any): + target_session_attrs=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2005,16 +2005,21 @@ async def connect(dsn=None, *, this connection object. Must be a subclass of :class:`~asyncpg.Record`. - :param SessionAttribute target_session_attribute: + :param SessionAttribute target_session_attrs: If specified, check that the host has the correct attribute. Can be one of: "any": the first successfully connected host "primary": the host must NOT be in hot standby mode "standby": the host must be in hot standby mode + "read-write": the host must allow writes + "read-only": the host most NOT allow writes "prefer-standby": first try to find a standby host, but if none of the listed hosts is a standby server, return any of them. + If not specified will try to use PGTARGETSESSIONATTRS from the environment. + Defaults to "any" if no value is set. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2099,15 +2104,6 @@ async def connect(dsn=None, *, if record_class is not protocol.Record: _check_record_class(record_class) - try: - target_session_attribute = SessionAttribute(target_session_attribute) - except ValueError as exc: - raise exceptions.InterfaceError( - "target_session_attribute is expected to be one of " - "'any', 'primary', 'standby' or 'prefer-standby'" - ", got {!r}".format(target_session_attribute) - ) from exc - if loop is None: loop = asyncio.get_event_loop() @@ -2130,7 +2126,7 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, - target_session_attribute=target_session_attribute + target_session_attrs=target_session_attrs ) diff --git a/tests/test_connect.py b/tests/test_connect.py index 3b8f69ff..baf2ebe5 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -788,7 +788,7 @@ def run_testcase(self, testcase): database = testcase.get('database') sslmode = testcase.get('ssl') server_settings = testcase.get('server_settings') - target_session_attribute = testcase.get('target_session_attribute') + target_session_attrs = testcase.get('target_session_attrs') expected = testcase.get('result') expected_error = testcase.get('error') @@ -813,7 +813,7 @@ def run_testcase(self, testcase): passfile=passfile, database=database, ssl=sslmode, direct_tls=False, connect_timeout=None, server_settings=server_settings, - target_session_attribute=target_session_attribute) + target_session_attrs=target_session_attrs) params = { k: v for k, v in params._asdict().items() @@ -1750,7 +1750,7 @@ class TestConnectionAttributes(tb.HotStandbyTestCase): async def _run_connection_test( self, connect, target_attribute, expected_port ): - conn = await connect(target_session_attribute=target_attribute) + conn = await connect(target_session_attrs=target_attribute) self.assertTrue(_get_connected_host(conn).endswith(expected_port)) await conn.close() @@ -1790,7 +1790,7 @@ async def test_target_attribute_not_matched(self): for connect, target_attr in tests: with self.assertRaises(exceptions.TargetServerAttributeNotMatched): - await connect(target_session_attribute=target_attr) + await connect(target_session_attrs=target_attr) if self.master_cluster.get_pg_version()[0] < 14: self.skipTest("PostgreSQL<14 does not support these features") @@ -1801,12 +1801,12 @@ async def test_target_attribute_not_matched(self): for connect, target_attr in tests: with self.assertRaises(exceptions.TargetServerAttributeNotMatched): - await connect(target_session_attribute=target_attr) + await connect(target_session_attrs=target_attr) async def test_prefer_standby_when_standby_is_up(self): if self.master_cluster.get_pg_version()[0] == 11: self.skipTest("PostgreSQL 11 seems to have issues with this test") - con = await self.connect(target_session_attribute='prefer-standby') + con = await self.connect(target_session_attrs='prefer-standby') standby_port = self.standby_cluster.get_connection_spec()['port'] connected_host = _get_connected_host(con) self.assertTrue(connected_host.endswith(standby_port)) @@ -1824,7 +1824,7 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): 'port': [primary_spec['port'], 15345], 'database': primary_spec['database'], 'user': primary_spec['user'], - 'target_session_attribute': 'prefer-standby' + 'target_session_attrs': 'prefer-standby' } con = await self.connect(**connection_spec) From 08ad1f84e5c5deaf6a6a5a2274fd4dd1b13eafc8 Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Sun, 29 Jan 2023 11:15:55 +0100 Subject: [PATCH 20/23] fix tests --- asyncpg/connect_utils.py | 8 +++-- asyncpg/connection.py | 4 +-- tests/test_connect.py | 66 ++++++++++++++++++++++++++++++++-------- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 89c6c939..e0c10442 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -597,7 +597,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if target_session_attrs is None: - target_session_attrs = os.getenv("PGTARGETSESSIONATTRS", SessionAttribute.any) + target_session_attrs = os.getenv( + "PGTARGETSESSIONATTRS", SessionAttribute.any + ) try: target_session_attrs = SessionAttribute(target_session_attrs) @@ -605,7 +607,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, raise exceptions.InterfaceError( "target_session_attrs is expected to be one of " "{!r}" - ", got {!r}".format(SessionAttribute.__members__.values, target_session_attrs) + ", got {!r}".format( + SessionAttribute.__members__.values, target_session_attrs + ) ) from exc params = _ConnectionParameters( diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 095ad398..432fcef6 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -29,7 +29,6 @@ from . import serverversion from . import transaction from . import utils -from .connect_utils import SessionAttribute class ConnectionMeta(type): @@ -2017,7 +2016,8 @@ async def connect(dsn=None, *, none of the listed hosts is a standby server, return any of them. - If not specified will try to use PGTARGETSESSIONATTRS from the environment. + If not specified will try to use PGTARGETSESSIONATTRS + from the environment. Defaults to "any" if no value is set. :return: A :class:`~asyncpg.connection.Connection` instance. diff --git a/tests/test_connect.py b/tests/test_connect.py index baf2ebe5..b38963ad 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -384,7 +384,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw', 'database': 'testdb', 'ssl': True, - 'sslmode': SSLMode.prefer}) + 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any'}) }, { @@ -406,7 +407,8 @@ class TestConnectParams(tb.TestCase): 'result': ([('host2', 456)], { 'user': 'user2', 'password': 'passw2', - 'database': 'db2'}) + 'database': 'db2', + 'target_session_attrs': 'any'}) }, { @@ -434,7 +436,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw2', 'database': 'db2', 'sslmode': SSLMode.disable, - 'ssl': False}) + 'ssl': False, + 'target_session_attrs': 'any'}) }, { @@ -455,7 +458,8 @@ class TestConnectParams(tb.TestCase): 'password': '123123', 'database': 'abcdef', 'ssl': True, - 'sslmode': SSLMode.allow}) + 'sslmode': SSLMode.allow, + 'target_session_attrs': 'any'}) }, { @@ -483,7 +487,8 @@ class TestConnectParams(tb.TestCase): 'password': 'passw2', 'database': 'db2', 'sslmode': SSLMode.disable, - 'ssl': False}) + 'ssl': False, + 'target_session_attrs': 'any'}) }, { @@ -504,7 +509,8 @@ class TestConnectParams(tb.TestCase): 'password': '123123', 'database': 'abcdef', 'ssl': True, - 'sslmode': SSLMode.prefer}) + 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any'}) }, { @@ -513,7 +519,8 @@ class TestConnectParams(tb.TestCase): 'result': ([('localhost', 5555)], { 'user': 'user3', 'password': '123123', - 'database': 'abcdef'}) + 'database': 'abcdef', + 'target_session_attrs': 'any'}) }, { @@ -522,6 +529,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -531,6 +539,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -540,6 +549,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -549,6 +559,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -558,6 +569,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('2001:db8::1234', 5432), ('::1', 5432)], { 'database': 'db', 'user': 'user', + 'target_session_attrs': 'any', }) }, @@ -572,6 +584,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -584,6 +597,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -597,6 +611,7 @@ class TestConnectParams(tb.TestCase): 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'foo', + 'target_session_attrs': 'any', }) }, @@ -616,7 +631,8 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'ssl': True, - 'sslmode': SSLMode.require}) + 'sslmode': SSLMode.require, + 'target_session_attrs': 'any'}) }, { @@ -637,7 +653,8 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'sslmode': SSLMode.verify_full, - 'ssl': True}) + 'ssl': True, + 'target_session_attrs': 'any'}) }, { @@ -645,7 +662,8 @@ class TestConnectParams(tb.TestCase): 'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam', 'result': ([os.path.join('/unix_sock/test', '.s.PGSQL.5432')], { 'user': 'spam', - 'database': 'dbname'}) + 'database': 'dbname', + 'target_session_attrs': 'any'}) }, { @@ -657,6 +675,7 @@ class TestConnectParams(tb.TestCase): 'user': 'us@r', 'password': 'p@ss', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -670,6 +689,7 @@ class TestConnectParams(tb.TestCase): 'user': 'user', 'password': 'p', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -682,6 +702,7 @@ class TestConnectParams(tb.TestCase): { 'user': 'us@r', 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -709,7 +730,8 @@ class TestConnectParams(tb.TestCase): 'user': 'user', 'database': 'user', 'sslmode': SSLMode.disable, - 'ssl': None + 'ssl': None, + 'target_session_attrs': 'any', } ) }, @@ -723,7 +745,8 @@ class TestConnectParams(tb.TestCase): '.s.PGSQL.5432' )], { 'user': 'spam', - 'database': 'db' + 'database': 'db', + 'target_session_attrs': 'any', } ) }, @@ -744,6 +767,7 @@ class TestConnectParams(tb.TestCase): 'database': 'db', 'ssl': True, 'sslmode': SSLMode.prefer, + 'target_session_attrs': 'any', } ) }, @@ -874,7 +898,9 @@ def test_test_connect_params_run_testcase(self): 'host': 'abc', 'result': ( [('abc', 5432)], - {'user': '__test__', 'database': '__test__'} + {'user': '__test__', + 'database': '__test__', + 'target_session_attrs': 'any'} ) }) @@ -912,6 +938,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -928,6 +955,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -942,6 +970,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for user@abc', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -957,6 +986,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for localhost', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -974,6 +1004,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for localhost', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -991,6 +1022,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for cde:5433', 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1007,6 +1039,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for testuser', 'user': 'testuser', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1023,6 +1056,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass for testdb', 'user': 'user', 'database': 'testdb', + 'target_session_attrs': 'any', } ) }) @@ -1039,6 +1073,7 @@ def test_connect_pgpass_regular(self): 'password': 'password from pgpass with escapes', 'user': R'test\\', 'database': R'test\:db', + 'target_session_attrs': 'any', } ) }) @@ -1066,6 +1101,7 @@ def test_connect_pgpass_badness_mode(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1086,6 +1122,7 @@ def test_connect_pgpass_badness_non_file(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1102,6 +1139,7 @@ def test_connect_pgpass_nonexistent(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1122,6 +1160,7 @@ def test_connect_pgpass_inaccessible_file(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) @@ -1144,6 +1183,7 @@ def test_connect_pgpass_inaccessible_directory(self): { 'user': 'user', 'database': 'db', + 'target_session_attrs': 'any', } ) }) From 7cdd2ba2212e78b2f65fe89f6bf9d9e076325d5c Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Sun, 29 Jan 2023 11:19:44 +0100 Subject: [PATCH 21/23] See the results on pg11 --- tests/test_connect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index b38963ad..3b56241b 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1795,8 +1795,8 @@ async def _run_connection_test( await conn.close() async def test_target_server_attribute_port(self): - if self.master_cluster.get_pg_version()[0] == 11: - self.skipTest("PostgreSQL 11 seems to have issues with this test") + #if self.master_cluster.get_pg_version()[0] == 11: + # self.skipTest("PostgreSQL 11 seems to have issues with this test") master_port = self.master_cluster.get_connection_spec()['port'] standby_port = self.standby_cluster.get_connection_spec()['port'] tests = [ From 19b7a17c8a3be0181027c1e03a224bfc60254a8d Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Mon, 30 Jan 2023 12:39:06 +0100 Subject: [PATCH 22/23] ... --- tests/test_connect.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 3b56241b..02a6a50b 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1795,8 +1795,6 @@ async def _run_connection_test( await conn.close() async def test_target_server_attribute_port(self): - #if self.master_cluster.get_pg_version()[0] == 11: - # self.skipTest("PostgreSQL 11 seems to have issues with this test") master_port = self.master_cluster.get_connection_spec()['port'] standby_port = self.standby_cluster.get_connection_spec()['port'] tests = [ From ea002ed903dff5a15d8a0647ffccb8cdf70f546d Mon Sep 17 00:00:00 2001 From: Jesse De Loore Date: Thu, 4 May 2023 19:47:31 +0200 Subject: [PATCH 23/23] Apply fix --- asyncpg/cluster.py | 2 +- tests/test_connect.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py index 0999e41c..4467cc2a 100644 --- a/asyncpg/cluster.py +++ b/asyncpg/cluster.py @@ -626,7 +626,7 @@ def init(self, **settings): 'pg_basebackup init exited with status {:d}:\n{}'.format( process.returncode, output.decode())) - if self._pg_version <= (11, 0): + if self._pg_version < (12, 0): with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f: f.write(textwrap.dedent("""\ standby_mode = 'on' diff --git a/tests/test_connect.py b/tests/test_connect.py index 02a6a50b..3701b5e2 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1819,8 +1819,6 @@ async def test_target_server_attribute_port(self): ) async def test_target_attribute_not_matched(self): - if self.master_cluster.get_pg_version()[0] == 11: - self.skipTest("PostgreSQL 11 seems to have issues with this test") tests = [ (self.connect_standby, 'primary'), (self.connect_primary, 'standby'), @@ -1842,8 +1840,6 @@ async def test_target_attribute_not_matched(self): await connect(target_session_attrs=target_attr) async def test_prefer_standby_when_standby_is_up(self): - if self.master_cluster.get_pg_version()[0] == 11: - self.skipTest("PostgreSQL 11 seems to have issues with this test") con = await self.connect(target_session_attrs='prefer-standby') standby_port = self.standby_cluster.get_connection_spec()['port'] connected_host = _get_connected_host(con) @@ -1851,8 +1847,6 @@ async def test_prefer_standby_when_standby_is_up(self): await con.close() async def test_prefer_standby_picks_master_when_standby_is_down(self): - if self.master_cluster.get_pg_version()[0] == 11: - self.skipTest("PostgreSQL 11 seems to have issues with this test") primary_spec = self.get_cluster_connection_spec(self.master_cluster) connection_spec = { 'host': [