1313import os
1414import pathlib
1515import platform
16+ import random
1617import re
1718import socket
1819import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
5657 'direct_tls' ,
5758 'connect_timeout' ,
5859 'server_settings' ,
60+ 'target_session_attribute' ,
5961 ])
6062
6163
@@ -259,7 +261,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
259261
260262def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
261263 password , passfile , database , ssl ,
262- direct_tls , connect_timeout , server_settings ):
264+ direct_tls , connect_timeout , server_settings ,
265+ target_session_attribute ):
263266 # `auth_hosts` is the version of host information for the purposes
264267 # of reading the pgpass file.
265268 auth_hosts = None
@@ -603,7 +606,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
603606 params = _ConnectionParameters (
604607 user = user , password = password , database = database , ssl = ssl ,
605608 sslmode = sslmode , direct_tls = direct_tls ,
606- connect_timeout = connect_timeout , server_settings = server_settings )
609+ connect_timeout = connect_timeout , server_settings = server_settings ,
610+ target_session_attribute = target_session_attribute )
607611
608612 return addrs , params
609613
@@ -613,8 +617,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
613617 statement_cache_size ,
614618 max_cached_statement_lifetime ,
615619 max_cacheable_statement_size ,
616- ssl , direct_tls , server_settings ):
617-
620+ ssl , direct_tls , server_settings ,
621+ target_session_attribute ):
618622 local_vars = locals ()
619623 for var_name in {'max_cacheable_statement_size' ,
620624 'max_cached_statement_lifetime' ,
@@ -642,7 +646,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
642646 dsn = dsn , host = host , port = port , user = user ,
643647 password = password , passfile = passfile , ssl = ssl ,
644648 direct_tls = direct_tls , database = database ,
645- connect_timeout = timeout , server_settings = server_settings )
649+ connect_timeout = timeout , server_settings = server_settings ,
650+ target_session_attribute = target_session_attribute )
646651
647652 config = _ClientConfiguration (
648653 command_timeout = command_timeout ,
@@ -875,18 +880,64 @@ async def __connect_addr(
875880 return con
876881
877882
883+ class SessionAttribute (str , enum .Enum ):
884+ any = 'any'
885+ primary = 'primary'
886+ standby = 'standby'
887+ prefer_standby = 'prefer-standby'
888+
889+
890+ def _accept_in_hot_standby (should_be_in_hot_standby : bool ):
891+ """
892+ If the server didn't report "in_hot_standby" at startup, we must determine
893+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
894+ """
895+ async def can_be_used (connection ):
896+ settings = connection .get_settings ()
897+ hot_standby_status = getattr (settings , 'in_hot_standby' , None )
898+ if hot_standby_status is not None :
899+ is_in_hot_standby = hot_standby_status == 'on'
900+ else :
901+ is_in_hot_standby = await connection .fetchval (
902+ "SELECT pg_catalog.pg_is_in_recovery()"
903+ )
904+
905+ return is_in_hot_standby == should_be_in_hot_standby
906+
907+ return can_be_used
908+
909+
910+ async def _accept_any (_ ):
911+ return True
912+
913+
914+ target_attrs_check = {
915+ SessionAttribute .any : _accept_any ,
916+ SessionAttribute .primary : _accept_in_hot_standby (False ),
917+ SessionAttribute .standby : _accept_in_hot_standby (True ),
918+ SessionAttribute .prefer_standby : _accept_in_hot_standby (True ),
919+ }
920+
921+
922+ async def _can_use_connection (connection , attr : SessionAttribute ):
923+ can_use = target_attrs_check [attr ]
924+ return await can_use (connection )
925+
926+
878927async def _connect (* , loop , timeout , connection_class , record_class , ** kwargs ):
879928 if loop is None :
880929 loop = asyncio .get_event_loop ()
881930
882931 addrs , params , config = _parse_connect_arguments (timeout = timeout , ** kwargs )
932+ target_attr = params .target_session_attribute
883933
934+ candidates = []
935+ chosen_connection = None
884936 last_error = None
885- addr = None
886937 for addr in addrs :
887938 before = time .monotonic ()
888939 try :
889- return await _connect_addr (
940+ conn = await _connect_addr (
890941 addr = addr ,
891942 loop = loop ,
892943 timeout = timeout ,
@@ -895,12 +946,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
895946 connection_class = connection_class ,
896947 record_class = record_class ,
897948 )
949+ candidates .append (conn )
950+ if await _can_use_connection (conn , target_attr ):
951+ chosen_connection = conn
952+ break
898953 except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
899954 last_error = ex
900955 finally :
901956 timeout -= time .monotonic () - before
957+ else :
958+ if target_attr == SessionAttribute .prefer_standby and candidates :
959+ chosen_connection = random .choice (candidates )
960+
961+ await asyncio .gather (
962+ (c .close () for c in candidates if c is not chosen_connection ),
963+ return_exceptions = True
964+ )
965+
966+ if chosen_connection :
967+ return chosen_connection
902968
903- raise last_error
969+ raise last_error or exceptions .TargetServerAttributeNotMatched (
970+ 'None of the hosts match the target attribute requirement '
971+ '{!r}' .format (target_attr )
972+ )
904973
905974
906975async def _cancel (* , loop , addr , params : _ConnectionParameters ,
0 commit comments