@@ -42,11 +42,12 @@ class Connection(metaclass=ConnectionMeta):
42
42
'_stmt_cache' , '_stmts_to_close' ,
43
43
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
44
44
'_server_version' , '_server_caps' , '_intro_query' ,
45
- '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
45
+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46
+ '_ssl_context' )
46
47
47
48
def __init__ (self , protocol , transport , loop , addr , opts , * ,
48
49
statement_cache_size , command_timeout ,
49
- max_cached_statement_lifetime ):
50
+ max_cached_statement_lifetime , ssl_context ):
50
51
self ._protocol = protocol
51
52
self ._transport = transport
52
53
self ._loop = loop
@@ -58,6 +59,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
58
59
59
60
self ._addr = addr
60
61
self ._opts = opts
62
+ self ._ssl_context = ssl_context
61
63
62
64
self ._stmt_cache = _StatementCache (
63
65
loop = loop ,
@@ -521,12 +523,24 @@ async def cancel():
521
523
r , w = await asyncio .open_unix_connection (
522
524
self ._addr , loop = self ._loop )
523
525
else :
524
- r , w = await asyncio .open_connection (
525
- * self ._addr , loop = self ._loop )
526
-
527
- sock = w .transport .get_extra_info ('socket' )
528
- sock .setsockopt (socket .IPPROTO_TCP ,
529
- socket .TCP_NODELAY , 1 )
526
+ if self ._ssl_context :
527
+ sock = await _get_ssl_ready_socket (
528
+ * self ._addr , loop = self ._loop )
529
+
530
+ try :
531
+ r , w = await asyncio .open_connection (
532
+ sock = sock ,
533
+ loop = self ._loop ,
534
+ ssl = self ._ssl_context ,
535
+ server_hostname = self ._addr [0 ])
536
+ except Exception :
537
+ sock .close ()
538
+ raise
539
+
540
+ else :
541
+ r , w = await asyncio .open_connection (
542
+ * self ._addr , loop = self ._loop )
543
+ _set_nodelay (_get_socket (w .transport ))
530
544
531
545
# Pack CancelRequest message
532
546
msg = struct .pack ('!llll' , 16 , 80877102 ,
@@ -708,9 +722,10 @@ async def connect(dsn=None, *,
708
722
statement_cache_size = 100 ,
709
723
max_cached_statement_lifetime = 300 ,
710
724
command_timeout = None ,
725
+ ssl = None ,
711
726
__connection_class__ = Connection ,
712
727
** opts ):
713
- """A coroutine to establish a connection to a PostgreSQL server.
728
+ r """A coroutine to establish a connection to a PostgreSQL server.
714
729
715
730
Returns a new :class:`~asyncpg.connection.Connection` object.
716
731
@@ -761,6 +776,12 @@ async def connect(dsn=None, *,
761
776
the default timeout for operations on this connection
762
777
(the default is no timeout).
763
778
779
+ :param ssl:
780
+ pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
781
+ require an SSL connection. If ``True``, a default SSL context
782
+ returned by `ssl.create_default_context() <create_default_context_>`_
783
+ will be used.
784
+
764
785
:return: A :class:`~asyncpg.connection.Connection` instance.
765
786
766
787
Example:
@@ -778,42 +799,51 @@ async def connect(dsn=None, *,
778
799
779
800
.. versionchanged:: 0.10.0
780
801
Added ``max_cached_statement_use_count`` parameter.
802
+
803
+ .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
804
+ .. _create_default_context: https://docs.python.org/3/library/ssl.html#\
805
+ ssl.create_default_context
781
806
"""
782
807
if loop is None :
783
808
loop = asyncio .get_event_loop ()
784
809
785
- host , port , opts = _parse_connect_params (
810
+ addrs , opts = _parse_connect_params (
786
811
dsn = dsn , host = host , port = port , user = user , password = password ,
787
812
database = database , opts = opts )
788
813
789
- last_ex = None
814
+ if ssl :
815
+ for addr in addrs :
816
+ if isinstance (addr , str ):
817
+ # UNIX socket
818
+ raise exceptions .InterfaceError (
819
+ '`ssl` parameter can only be enabled for TCP addresses, '
820
+ 'got a UNIX socket path: {!r}' .format (addr ))
821
+
822
+ last_error = None
790
823
addr = None
791
- for h in host :
824
+ for addr in addrs :
792
825
connected = _create_future (loop )
793
- unix = h .startswith ('/' )
794
-
795
- if unix :
796
- # UNIX socket name
797
- addr = h
798
- if '.s.PGSQL.' not in addr :
799
- addr = os .path .join (addr , '.s.PGSQL.{}' .format (port ))
800
- conn = loop .create_unix_connection (
801
- lambda : protocol .Protocol (addr , connected , opts , loop ),
802
- addr )
826
+ proto_factory = lambda : protocol .Protocol (addr , connected , opts , loop )
827
+
828
+ if isinstance (addr , str ):
829
+ # UNIX socket
830
+ assert ssl is None
831
+ connector = loop .create_unix_connection (proto_factory , addr )
832
+ elif ssl :
833
+ connector = _create_ssl_connection (
834
+ proto_factory , * addr , loop = loop , ssl_context = ssl )
803
835
else :
804
- addr = (h , port )
805
- conn = loop .create_connection (
806
- lambda : protocol .Protocol (addr , connected , opts , loop ),
807
- h , port )
836
+ connector = loop .create_connection (proto_factory , * addr )
808
837
809
838
try :
810
- tr , pr = await asyncio .wait_for (conn , timeout = timeout , loop = loop )
811
- except (OSError , asyncio .TimeoutError ) as ex :
812
- last_ex = ex
839
+ tr , pr = await asyncio .wait_for (
840
+ connector , timeout = timeout , loop = loop )
841
+ except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
842
+ last_error = ex
813
843
else :
814
844
break
815
845
else :
816
- raise last_ex
846
+ raise last_error
817
847
818
848
try :
819
849
await connected
@@ -825,12 +855,60 @@ async def connect(dsn=None, *,
825
855
pr , tr , loop , addr , opts ,
826
856
statement_cache_size = statement_cache_size ,
827
857
max_cached_statement_lifetime = max_cached_statement_lifetime ,
828
- command_timeout = command_timeout )
858
+ command_timeout = command_timeout , ssl_context = ssl )
829
859
830
860
pr .set_connection (con )
831
861
return con
832
862
833
863
864
+ async def _get_ssl_ready_socket (host , port , * , loop ):
865
+ reader , writer = await asyncio .open_connection (host , port , loop = loop )
866
+
867
+ tr = writer .transport
868
+ try :
869
+ sock = _get_socket (tr )
870
+ _set_nodelay (sock )
871
+
872
+ writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
873
+ await writer .drain ()
874
+ resp = await reader .readexactly (1 )
875
+
876
+ if resp == b'S' :
877
+ return sock .dup ()
878
+ else :
879
+ raise ConnectionError (
880
+ 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
881
+ host , port ))
882
+ finally :
883
+ tr .close ()
884
+
885
+
886
+ async def _create_ssl_connection (protocol_factory , host , port , * ,
887
+ loop , ssl_context ):
888
+ sock = await _get_ssl_ready_socket (host , port , loop = loop )
889
+ try :
890
+ return await loop .create_connection (
891
+ protocol_factory , sock = sock , ssl = ssl_context ,
892
+ server_hostname = host )
893
+ except Exception :
894
+ sock .close ()
895
+ raise
896
+
897
+
898
+ def _get_socket (transport ):
899
+ sock = transport .get_extra_info ('socket' )
900
+ if sock is None :
901
+ # Shouldn't happen with any asyncio-complaint event loop.
902
+ raise ConnectionError (
903
+ 'could not get the socket for transport {!r}' .format (transport ))
904
+ return sock
905
+
906
+
907
+ def _set_nodelay (sock ):
908
+ if not hasattr (socket , 'AF_UNIX' ) or sock .family != socket .AF_UNIX :
909
+ sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
910
+
911
+
834
912
class _StatementCacheEntry :
835
913
836
914
__slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
@@ -1116,7 +1194,18 @@ def _parse_connect_params(*, dsn, host, port, user,
1116
1194
'invalid connection parameter {!r}: {!r} (str expected)'
1117
1195
.format (param , opts [param ]))
1118
1196
1119
- return host , port , opts
1197
+ addrs = []
1198
+ for h in host :
1199
+ if h .startswith ('/' ):
1200
+ # UNIX socket name
1201
+ if '.s.PGSQL.' not in h :
1202
+ h = os .path .join (h , '.s.PGSQL.{}' .format (port ))
1203
+ addrs .append (h )
1204
+ else :
1205
+ # TCP host/port
1206
+ addrs .append ((h , port ))
1207
+
1208
+ return addrs , opts
1120
1209
1121
1210
1122
1211
def _create_future (loop ):
0 commit comments