From 5fb06edbbb769561e245d0fe13002bab50e2ae60 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 3 May 2021 00:34:15 +0300 Subject: [PATCH] bpo-44011: New asyncio ssl implementation (#17975) --- Lib/asyncio/base_events.py | 43 +- Lib/asyncio/constants.py | 7 + Lib/asyncio/events.py | 19 +- Lib/asyncio/proactor_events.py | 12 +- Lib/asyncio/selector_events.py | 31 +- Lib/asyncio/sslproto.py | 1058 +++++----- Lib/asyncio/unix_events.py | 17 +- Lib/test/test_asyncio/test_base_events.py | 21 +- Lib/test/test_asyncio/test_selector_events.py | 38 - Lib/test/test_asyncio/test_ssl.py | 1718 +++++++++++++++++ Lib/test/test_asyncio/test_sslproto.py | 35 +- .../2021-05-02-23-44-21.bpo-44011.hd8iUO.rst | 2 + 12 files changed, 2474 insertions(+), 527 deletions(-) create mode 100644 Lib/test/test_asyncio/test_ssl.py create mode 100644 Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index f789635e0f893a..e54ee309e42e6f 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -273,7 +273,7 @@ async def restore(self): class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, - ssl_handshake_timeout): + ssl_handshake_timeout, ssl_shutdown_timeout=None): self._loop = loop self._sockets = sockets self._active_count = 0 @@ -282,6 +282,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, self._backlog = backlog self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None @@ -313,7 +314,8 @@ def _start_serving(self): sock.listen(self._backlog) self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, - self, self._backlog, self._ssl_handshake_timeout) + self, self._backlog, self._ssl_handshake_timeout, + self._ssl_shutdown_timeout) def get_loop(self): return self._loop @@ -467,6 +469,7 @@ def _make_ssl_transport( *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -969,6 +972,7 @@ async def create_connection( proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. @@ -1004,6 +1008,10 @@ async def create_connection( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: # If using happy eyeballs, default to interleave addresses by family interleave = 1 @@ -1079,7 +1087,8 @@ async def create_connection( transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1091,7 +1100,8 @@ async def create_connection( async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): sock.setblocking(False) @@ -1102,7 +1112,8 @@ async def _create_connection_transport( transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -1193,7 +1204,8 @@ async def _sendfile_fallback(self, transp, file, offset, count): async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1216,6 +1228,7 @@ async def start_tls(self, transport, protocol, sslcontext, *, self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1414,6 +1427,7 @@ async def create_server( reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """Create a TCP server. @@ -1437,6 +1451,10 @@ async def create_server( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and ssl is None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1509,7 +1527,8 @@ async def create_server( sock.setblocking(False) server = Server(self, sockets, protocol_factory, - ssl, backlog, ssl_handshake_timeout) + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' @@ -1523,7 +1542,8 @@ async def create_server( async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1531,9 +1551,14 @@ async def connect_accepted_socket( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index 33feed60e55b00..f171ead28fecd3 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -15,10 +15,17 @@ # The default timeout matches that of Nginx. SSL_HANDSHAKE_TIMEOUT = 60.0 +# Number of seconds to wait for SSL shutdown to complete +# The default timeout mimics lingering_time +SSL_SHUTDOWN_TIMEOUT = 30.0 + # Used in sendfile fallback code. We use fallback for platforms # that don't support sendfile, or for TLS connections. SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 +FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB +FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB + # The enum should be here to break circular dependencies between # base_events and sslproto class _SendfileMode(enum.Enum): diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index b966ad26bf467b..d5254fa5e7e73e 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -304,6 +304,7 @@ async def create_connection( flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): raise NotImplementedError @@ -313,6 +314,7 @@ async def create_server( flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """A coroutine which creates a TCP server bound to host and port. @@ -353,6 +355,10 @@ async def create_server( will wait for completion of the SSL handshake before aborting the connection. Default is 60s. + ssl_shutdown_timeout is the time in seconds that an SSL server + will wait for completion of the SSL shutdown procedure + before aborting the connection. Default is 30s. + start_serving set to True (default) causes the created server to start accepting connections immediately. When set to False, the user should await Server.start_serving() or Server.serve_forever() @@ -371,7 +377,8 @@ async def sendfile(self, transport, file, offset=0, count=None, async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Upgrade a transport to TLS. Return a new transport that *protocol* should start using @@ -383,13 +390,15 @@ async def create_unix_connection( self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): raise NotImplementedError async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): """A coroutine which creates a UNIX Domain Socket server. @@ -411,6 +420,9 @@ async def create_unix_server( ssl_handshake_timeout is the time in seconds that an SSL server will wait for the SSL handshake to complete (defaults to 60s). + ssl_shutdown_timeout is the time in seconds that an SSL server + will wait for the SSL shutdown to finish (defaults to 30s). + start_serving set to True (default) causes the created server to start accepting connections immediately. When set to False, the user should await Server.start_serving() or Server.serve_forever() @@ -421,7 +433,8 @@ async def create_unix_server( async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): """Handle an accepted connection. This is used by servers that accept connections outside of diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 45c11ee4b487ec..10852afe2b4138 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -642,11 +642,13 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) _ProactorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport @@ -812,7 +814,8 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): def loop(f=None): try: @@ -826,7 +829,8 @@ def loop(f=None): self._make_ssl_transport( conn, protocol, sslcontext, server_side=True, extra={'peername': addr}, server=server, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: self._make_socket_transport( conn, protocol, diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 59cb6b1babec54..63ab15f30fb5d7 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -70,11 +70,15 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, + ): ssl_protocol = sslproto.SSLProtocol( - self, protocol, sslcontext, waiter, - server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout + ) _SelectorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport @@ -146,15 +150,17 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): self._add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, sslcontext, server, backlog, - ssl_handshake_timeout) + ssl_handshake_timeout, ssl_shutdown_timeout) def _accept_connection( self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -185,20 +191,22 @@ def _accept_connection( self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, protocol_factory, sock, sslcontext, server, - backlog, ssl_handshake_timeout) + backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} accept = self._accept_connection2( protocol_factory, conn, extra, sslcontext, server, - ssl_handshake_timeout) + ssl_handshake_timeout, ssl_shutdown_timeout) self.create_task(accept) async def _accept_connection2( self, protocol_factory, conn, extra, sslcontext=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): protocol = None transport = None try: @@ -208,7 +216,8 @@ async def _accept_connection2( transport = self._make_ssl_transport( conn, protocol, sslcontext, waiter=waiter, server_side=True, extra=extra, server=server, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index cad25b26539f5a..e71875ba9f0093 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -1,4 +1,5 @@ import collections +import enum import warnings try: import ssl @@ -6,10 +7,37 @@ ssl = None from . import constants +from . import exceptions from . import protocols from . import transports from .log import logger +SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) + + +class SSLProtocolState(enum.Enum): + UNWRAPPED = "UNWRAPPED" + DO_HANDSHAKE = "DO_HANDSHAKE" + WRAPPED = "WRAPPED" + FLUSHING = "FLUSHING" + SHUTDOWN = "SHUTDOWN" + + +class AppProtocolState(enum.Enum): + # This tracks the state of app protocol (https://git.io/fj59P): + # + # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST + # + # * cm: connection_made() + # * dr: data_received() + # * er: eof_received() + # * cl: connection_lost() + + STATE_INIT = "STATE_INIT" + STATE_CON_MADE = "STATE_CON_MADE" + STATE_EOF = "STATE_EOF" + STATE_CON_LOST = "STATE_CON_LOST" + def _create_transport_context(server_side, server_hostname): if server_side: @@ -25,269 +53,35 @@ def _create_transport_context(server_side, server_hostname): return sslcontext -# States of an _SSLPipe. -_UNWRAPPED = "UNWRAPPED" -_DO_HANDSHAKE = "DO_HANDSHAKE" -_WRAPPED = "WRAPPED" -_SHUTDOWN = "SHUTDOWN" - - -class _SSLPipe(object): - """An SSL "Pipe". - - An SSL pipe allows you to communicate with an SSL/TLS protocol instance - through memory buffers. It can be used to implement a security layer for an - existing connection where you don't have access to the connection's file - descriptor, or for some reason you don't want to use it. - - An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, - data is passed through untransformed. In wrapped mode, application level - data is encrypted to SSL record level data and vice versa. The SSL record - level is the lowest level in the SSL protocol suite and is what travels - as-is over the wire. - - An SslPipe initially is in "unwrapped" mode. To start SSL, call - do_handshake(). To shutdown SSL again, call unwrap(). - """ - - max_size = 256 * 1024 # Buffer size passed to read() - - def __init__(self, context, server_side, server_hostname=None): - """ - The *context* argument specifies the ssl.SSLContext to use. - - The *server_side* argument indicates whether this is a server side or - client side transport. - - The optional *server_hostname* argument can be used to specify the - hostname you are connecting to. You may only specify this parameter if - the _ssl module supports Server Name Indication (SNI). - """ - self._context = context - self._server_side = server_side - self._server_hostname = server_hostname - self._state = _UNWRAPPED - self._incoming = ssl.MemoryBIO() - self._outgoing = ssl.MemoryBIO() - self._sslobj = None - self._need_ssldata = False - self._handshake_cb = None - self._shutdown_cb = None - - @property - def context(self): - """The SSL context passed to the constructor.""" - return self._context - - @property - def ssl_object(self): - """The internal ssl.SSLObject instance. - - Return None if the pipe is not wrapped. - """ - return self._sslobj - - @property - def need_ssldata(self): - """Whether more record level data is needed to complete a handshake - that is currently in progress.""" - return self._need_ssldata - - @property - def wrapped(self): - """ - Whether a security layer is currently in effect. - - Return False during handshake. - """ - return self._state == _WRAPPED - - def do_handshake(self, callback=None): - """Start the SSL handshake. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the handshake is complete. The callback will be - called with None if successful, else an exception instance. - """ - if self._state != _UNWRAPPED: - raise RuntimeError('handshake in progress or completed') - self._sslobj = self._context.wrap_bio( - self._incoming, self._outgoing, - server_side=self._server_side, - server_hostname=self._server_hostname) - self._state = _DO_HANDSHAKE - self._handshake_cb = callback - ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) - assert len(appdata) == 0 - return ssldata - - def shutdown(self, callback=None): - """Start the SSL shutdown sequence. - - Return a list of ssldata. A ssldata element is a list of buffers - - The optional *callback* argument can be used to install a callback that - will be called when the shutdown is complete. The callback will be - called without arguments. - """ - if self._state == _UNWRAPPED: - raise RuntimeError('no security layer present') - if self._state == _SHUTDOWN: - raise RuntimeError('shutdown in progress') - assert self._state in (_WRAPPED, _DO_HANDSHAKE) - self._state = _SHUTDOWN - self._shutdown_cb = callback - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - return ssldata - - def feed_eof(self): - """Send a potentially "ragged" EOF. - - This method will raise an SSL_ERROR_EOF exception if the EOF is - unexpected. - """ - self._incoming.write_eof() - ssldata, appdata = self.feed_ssldata(b'') - assert appdata == [] or appdata == [b''] - - def feed_ssldata(self, data, only_handshake=False): - """Feed SSL record level data into the pipe. - - The data must be a bytes instance. It is OK to send an empty bytes - instance. This can be used to get ssldata for a handshake initiated by - this endpoint. - - Return a (ssldata, appdata) tuple. The ssldata element is a list of - buffers containing SSL data that needs to be sent to the remote SSL. - - The appdata element is a list of buffers containing plaintext data that - needs to be forwarded to the application. The appdata list may contain - an empty buffer indicating an SSL "close_notify" alert. This alert must - be acknowledged by calling shutdown(). - """ - if self._state == _UNWRAPPED: - # If unwrapped, pass plaintext data straight through. - if data: - appdata = [data] - else: - appdata = [] - return ([], appdata) - - self._need_ssldata = False - if data: - self._incoming.write(data) - - ssldata = [] - appdata = [] - try: - if self._state == _DO_HANDSHAKE: - # Call do_handshake() until it doesn't raise anymore. - self._sslobj.do_handshake() - self._state = _WRAPPED - if self._handshake_cb: - self._handshake_cb(None) - if only_handshake: - return (ssldata, appdata) - # Handshake done: execute the wrapped block - - if self._state == _WRAPPED: - # Main state: read data from SSL until close_notify - while True: - chunk = self._sslobj.read(self.max_size) - appdata.append(chunk) - if not chunk: # close_notify - break +def add_flowcontrol_defaults(high, low, kb): + if high is None: + if low is None: + hi = kb * 1024 + else: + lo = low + hi = 4 * lo + else: + hi = high + if low is None: + lo = hi // 4 + else: + lo = low - elif self._state == _SHUTDOWN: - # Call shutdown() until it doesn't raise anymore. - self._sslobj.unwrap() - self._sslobj = None - self._state = _UNWRAPPED - if self._shutdown_cb: - self._shutdown_cb() - - elif self._state == _UNWRAPPED: - # Drain possible plaintext data after close_notify. - appdata.append(self._incoming.read()) - except (ssl.SSLError, ssl.CertificateError) as exc: - exc_errno = getattr(exc, 'errno', None) - if exc_errno not in ( - ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, - ssl.SSL_ERROR_SYSCALL): - if self._state == _DO_HANDSHAKE and self._handshake_cb: - self._handshake_cb(exc) - raise - self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) - - # Check for record level data that needs to be sent back. - # Happens for the initial handshake and renegotiations. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - return (ssldata, appdata) - - def feed_appdata(self, data, offset=0): - """Feed plaintext data into the pipe. - - Return an (ssldata, offset) tuple. The ssldata element is a list of - buffers containing record level data that needs to be sent to the - remote SSL instance. The offset is the number of plaintext bytes that - were processed, which may be less than the length of data. - - NOTE: In case of short writes, this call MUST be retried with the SAME - buffer passed into the *data* argument (i.e. the id() must be the - same). This is an OpenSSL requirement. A further particularity is that - a short write will always have offset == 0, because the _ssl module - does not enable partial writes. And even though the offset is zero, - there will still be encrypted data in ssldata. - """ - assert 0 <= offset <= len(data) - if self._state == _UNWRAPPED: - # pass through data in unwrapped mode - if offset < len(data): - ssldata = [data[offset:]] - else: - ssldata = [] - return (ssldata, len(data)) + if not hi >= lo >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (hi, lo)) - ssldata = [] - view = memoryview(data) - while True: - self._need_ssldata = False - try: - if offset < len(view): - offset += self._sslobj.write(view[offset:]) - except ssl.SSLError as exc: - # It is not allowed to call write() after unwrap() until the - # close_notify is acknowledged. We return the condition to the - # caller as a short write. - exc_errno = getattr(exc, 'errno', None) - if exc.reason == 'PROTOCOL_IS_SHUTDOWN': - exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ - if exc_errno not in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE, - ssl.SSL_ERROR_SYSCALL): - raise - self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) - - # See if there's any record level data back for us. - if self._outgoing.pending: - ssldata.append(self._outgoing.read()) - if offset == len(view) or self._need_ssldata: - break - return (ssldata, offset) + return hi, lo class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): + _start_tls_compatible = True _sendfile_compatible = constants._SendfileMode.FALLBACK def __init__(self, loop, ssl_protocol): self._loop = loop - # SSLProtocol instance self._ssl_protocol = ssl_protocol self._closed = False @@ -315,16 +109,15 @@ def close(self): self._closed = True self._ssl_protocol._start_shutdown() - def __del__(self, _warn=warnings.warn): + def __del__(self, _warnings=warnings): if not self._closed: - _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) - self.close() + self._closed = True + _warnings.warn( + "unclosed transport ", ResourceWarning) def is_reading(self): - tr = self._ssl_protocol._transport - if tr is None: - raise RuntimeError('SSL transport has not been initialized yet') - return tr.is_reading() + return not self._ssl_protocol._app_reading_paused def pause_reading(self): """Pause the receiving end. @@ -332,7 +125,7 @@ def pause_reading(self): No data will be passed to the protocol's data_received() method until resume_reading() is called. """ - self._ssl_protocol._transport.pause_reading() + self._ssl_protocol._pause_reading() def resume_reading(self): """Resume the receiving end. @@ -340,7 +133,7 @@ def resume_reading(self): Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._transport.resume_reading() + self._ssl_protocol._resume_reading() def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -361,16 +154,51 @@ def set_write_buffer_limits(self, high=None, low=None): reduces opportunities for doing I/O and computation concurrently. """ - self._ssl_protocol._transport.set_write_buffer_limits(high, low) + self._ssl_protocol._set_write_buffer_limits(high, low) + self._ssl_protocol._control_app_writing() + + def get_write_buffer_limits(self): + return (self._ssl_protocol._outgoing_low_water, + self._ssl_protocol._outgoing_high_water) def get_write_buffer_size(self): - """Return the current size of the write buffer.""" - return self._ssl_protocol._transport.get_write_buffer_size() + """Return the current size of the write buffers.""" + return self._ssl_protocol._get_write_buffer_size() + + def set_read_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for read flow control. + + These two values control when to call the upstream transport's + pause_reading() and resume_reading() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to an + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_reading() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_reading() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._set_read_buffer_limits(high, low) + self._ssl_protocol._control_ssl_reading() + + def get_read_buffer_limits(self): + return (self._ssl_protocol._incoming_low_water, + self._ssl_protocol._incoming_high_water) + + def get_read_buffer_size(self): + """Return the current size of the read buffer.""" + return self._ssl_protocol._get_read_buffer_size() @property def _protocol_paused(self): # Required for sendfile fallback pause_writing/resume_writing logic - return self._ssl_protocol._transport._protocol_paused + return self._ssl_protocol._app_writing_paused def write(self, data): """Write some data bytes to the transport. @@ -383,7 +211,22 @@ def write(self, data): f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata(data) + self._ssl_protocol._write_appdata((data,)) + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + self._ssl_protocol._write_appdata(list_of_data) + + def write_eof(self): + """Close the write end after flushing buffered data. + + This raises :exc:`NotImplementedError` right now. + """ + raise NotImplementedError def can_write_eof(self): """Return True if this transport supports write_eof(), False if not.""" @@ -396,23 +239,36 @@ def abort(self): The protocol's connection_lost() method will (eventually) be called with None as its argument. """ + self._closed = True self._ssl_protocol._abort() + + def _force_close(self, exc): self._closed = True + self._ssl_protocol._abort(exc) + def _test__append_write_backlog(self, data): + # for test only + self._ssl_protocol._write_backlog.append(data) + self._ssl_protocol._write_buffer_size += len(data) -class SSLProtocol(protocols.Protocol): - """SSL protocol. - Implementation of SSL on top of a socket using incoming and outgoing - buffers which are ssl.MemoryBIO objects. - """ +class SSLProtocol(protocols.BufferedProtocol): + max_size = 256 * 1024 # Buffer size passed to read() + + _handshake_start_time = None + _handshake_timeout_handle = None + _shutdown_timeout_handle = None def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): if ssl is None: - raise RuntimeError('stdlib ssl module not available') + raise RuntimeError("stdlib ssl module not available") + + self._ssl_buffer = bytearray(self.max_size) + self._ssl_buffer_view = memoryview(self._ssl_buffer) if ssl_handshake_timeout is None: ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT @@ -420,6 +276,12 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, raise ValueError( f"ssl_handshake_timeout should be a positive number, " f"got {ssl_handshake_timeout}") + if ssl_shutdown_timeout is None: + ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT + elif ssl_shutdown_timeout <= 0: + raise ValueError( + f"ssl_shutdown_timeout should be a positive number, " + f"got {ssl_shutdown_timeout}") if not sslcontext: sslcontext = _create_transport_context( @@ -442,21 +304,54 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._waiter = waiter self._loop = loop self._set_app_protocol(app_protocol) - self._app_transport = _SSLProtocolTransport(self._loop, self) - # _SSLPipe instance (None until the connection is made) - self._sslpipe = None - self._session_established = False - self._in_handshake = False - self._in_shutdown = False + self._app_transport = None + self._app_transport_created = False # transport, ex: SelectorSocketTransport self._transport = None - self._call_connection_made = call_connection_made self._ssl_handshake_timeout = ssl_handshake_timeout + self._ssl_shutdown_timeout = ssl_shutdown_timeout + # SSL and state machine + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._state = SSLProtocolState.UNWRAPPED + self._conn_lost = 0 # Set when connection_lost called + if call_connection_made: + self._app_state = AppProtocolState.STATE_INIT + else: + self._app_state = AppProtocolState.STATE_CON_MADE + self._sslobj = self._sslcontext.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + + # Flow Control + + self._ssl_writing_paused = False + + self._app_reading_paused = False + + self._ssl_reading_paused = False + self._incoming_high_water = 0 + self._incoming_low_water = 0 + self._set_read_buffer_limits() + self._eof_received = False + + self._app_writing_paused = False + self._outgoing_high_water = 0 + self._outgoing_low_water = 0 + self._set_write_buffer_limits() + self._get_app_transport() def _set_app_protocol(self, app_protocol): self._app_protocol = app_protocol - self._app_protocol_is_buffer = \ - isinstance(app_protocol, protocols.BufferedProtocol) + # Make fast hasattr check first + if (hasattr(app_protocol, 'get_buffer') and + isinstance(app_protocol, protocols.BufferedProtocol)): + self._app_protocol_get_buffer = app_protocol.get_buffer + self._app_protocol_buffer_updated = app_protocol.buffer_updated + self._app_protocol_is_buffer = True + else: + self._app_protocol_is_buffer = False def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -468,15 +363,20 @@ def _wakeup_waiter(self, exc=None): self._waiter.set_result(None) self._waiter = None + def _get_app_transport(self): + if self._app_transport is None: + if self._app_transport_created: + raise RuntimeError('Creating _SSLProtocolTransport twice') + self._app_transport = _SSLProtocolTransport(self._loop, self) + self._app_transport_created = True + return self._app_transport + def connection_made(self, transport): """Called when the low-level connection is made. Start the SSL handshake. """ self._transport = transport - self._sslpipe = _SSLPipe(self._sslcontext, - self._server_side, - self._server_hostname) self._start_handshake() def connection_lost(self, exc): @@ -486,72 +386,58 @@ def connection_lost(self, exc): meaning a regular EOF is received or the connection was aborted or closed). """ - if self._session_established: - self._session_established = False - self._loop.call_soon(self._app_protocol.connection_lost, exc) - else: - # Most likely an exception occurred while in SSL handshake. - # Just mark the app transport as closed so that its __del__ - # doesn't complain. - if self._app_transport is not None: - self._app_transport._closed = True + self._write_backlog.clear() + self._outgoing.read() + self._conn_lost += 1 + + # Just mark the app transport as closed so that its __dealloc__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True + + if self._state != SSLProtocolState.DO_HANDSHAKE: + if ( + self._app_state == AppProtocolState.STATE_CON_MADE or + self._app_state == AppProtocolState.STATE_EOF + ): + self._app_state = AppProtocolState.STATE_CON_LOST + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._set_state(SSLProtocolState.UNWRAPPED) self._transport = None self._app_transport = None - if getattr(self, '_handshake_timeout_handle', None): - self._handshake_timeout_handle.cancel() - self._wakeup_waiter(exc) self._app_protocol = None - self._sslpipe = None + self._wakeup_waiter(exc) - def pause_writing(self): - """Called when the low-level transport's buffer goes over - the high-water mark. - """ - self._app_protocol.pause_writing() + if self._shutdown_timeout_handle: + self._shutdown_timeout_handle.cancel() + self._shutdown_timeout_handle = None + if self._handshake_timeout_handle: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None - def resume_writing(self): - """Called when the low-level transport's buffer drains below - the low-water mark. - """ - self._app_protocol.resume_writing() + def get_buffer(self, n): + want = n + if want <= 0 or want > self.max_size: + want = self.max_size + if len(self._ssl_buffer) < want: + self._ssl_buffer = bytearray(want) + self._ssl_buffer_view = memoryview(self._ssl_buffer) + return self._ssl_buffer_view - def data_received(self, data): - """Called when some SSL data is received. + def buffer_updated(self, nbytes): + self._incoming.write(self._ssl_buffer_view[:nbytes]) - The argument is a bytes object. - """ - if self._sslpipe is None: - # transport closing, sslpipe is destroyed - return + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._do_handshake() - try: - ssldata, appdata = self._sslpipe.feed_ssldata(data) - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as e: - self._fatal_error(e, 'SSL error in data received') - return + elif self._state == SSLProtocolState.WRAPPED: + self._do_read() - for chunk in ssldata: - self._transport.write(chunk) + elif self._state == SSLProtocolState.FLUSHING: + self._do_flush() - for chunk in appdata: - if chunk: - try: - if self._app_protocol_is_buffer: - protocols._feed_data_to_buffered_proto( - self._app_protocol, chunk) - else: - self._app_protocol.data_received(chunk) - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as ex: - self._fatal_error( - ex, 'application protocol failed to receive SSL data') - return - else: - self._start_shutdown() - break + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() def eof_received(self): """Called when the other end of the low-level stream @@ -561,19 +447,32 @@ def eof_received(self): will close itself. If it returns a true value, closing the transport is up to the protocol. """ + self._eof_received = True try: if self._loop.get_debug(): logger.debug("%r received EOF", self) - self._wakeup_waiter(ConnectionResetError) + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._on_handshake_complete(ConnectionResetError) - if not self._in_handshake: - keep_open = self._app_protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - finally: + elif self._state == SSLProtocolState.WRAPPED: + self._set_state(SSLProtocolState.FLUSHING) + if self._app_reading_paused: + return True + else: + self._do_flush() + + elif self._state == SSLProtocolState.FLUSHING: + self._do_write() + self._set_state(SSLProtocolState.SHUTDOWN) + self._do_shutdown() + + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() + + except Exception: self._transport.close() + raise def _get_extra_info(self, name, default=None): if name in self._extra: @@ -583,19 +482,45 @@ def _get_extra_info(self, name, default=None): else: return default - def _start_shutdown(self): - if self._in_shutdown: - return - if self._in_handshake: - self._abort() + def _set_state(self, new_state): + allowed = False + + if new_state == SSLProtocolState.UNWRAPPED: + allowed = True + + elif ( + self._state == SSLProtocolState.UNWRAPPED and + new_state == SSLProtocolState.DO_HANDSHAKE + ): + allowed = True + + elif ( + self._state == SSLProtocolState.DO_HANDSHAKE and + new_state == SSLProtocolState.WRAPPED + ): + allowed = True + + elif ( + self._state == SSLProtocolState.WRAPPED and + new_state == SSLProtocolState.FLUSHING + ): + allowed = True + + elif ( + self._state == SSLProtocolState.FLUSHING and + new_state == SSLProtocolState.SHUTDOWN + ): + allowed = True + + if allowed: + self._state = new_state + else: - self._in_shutdown = True - self._write_appdata(b'') + raise RuntimeError( + 'cannot switch state from {} to {}'.format( + self._state, new_state)) - def _write_appdata(self, data): - self._write_backlog.append((data, 0)) - self._write_buffer_size += len(data) - self._process_write_backlog() + # Handshake flow def _start_handshake(self): if self._loop.get_debug(): @@ -603,17 +528,18 @@ def _start_handshake(self): self._handshake_start_time = self._loop.time() else: self._handshake_start_time = None - self._in_handshake = True - # (b'', 1) is a special value in _process_write_backlog() to do - # the SSL handshake - self._write_backlog.append((b'', 1)) + + self._set_state(SSLProtocolState.DO_HANDSHAKE) + + # start handshake timeout count down self._handshake_timeout_handle = \ self._loop.call_later(self._ssl_handshake_timeout, - self._check_handshake_timeout) - self._process_write_backlog() + lambda: self._check_handshake_timeout()) + + self._do_handshake() def _check_handshake_timeout(self): - if self._in_handshake is True: + if self._state == SSLProtocolState.DO_HANDSHAKE: msg = ( f"SSL handshake is taking longer than " f"{self._ssl_handshake_timeout} seconds: " @@ -621,24 +547,37 @@ def _check_handshake_timeout(self): ) self._fatal_error(ConnectionAbortedError(msg)) + def _do_handshake(self): + try: + self._sslobj.do_handshake() + except SSLAgainErrors: + self._process_outgoing() + except ssl.SSLError as exc: + self._on_handshake_complete(exc) + else: + self._on_handshake_complete(None) + def _on_handshake_complete(self, handshake_exc): - self._in_handshake = False - self._handshake_timeout_handle.cancel() + if self._handshake_timeout_handle is not None: + self._handshake_timeout_handle.cancel() + self._handshake_timeout_handle = None - sslobj = self._sslpipe.ssl_object + sslobj = self._sslobj try: - if handshake_exc is not None: + if handshake_exc is None: + self._set_state(SSLProtocolState.WRAPPED) + else: raise handshake_exc peercert = sslobj.getpeercert() - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as exc: + except Exception as exc: + self._set_state(SSLProtocolState.UNWRAPPED) if isinstance(exc, ssl.CertificateError): msg = 'SSL handshake failed on verifying the certificate' else: msg = 'SSL handshake failed' self._fatal_error(exc, msg) + self._wakeup_waiter(exc) return if self._loop.get_debug(): @@ -649,85 +588,330 @@ def _on_handshake_complete(self, handshake_exc): self._extra.update(peercert=peercert, cipher=sslobj.cipher(), compression=sslobj.compression(), - ssl_object=sslobj, - ) - if self._call_connection_made: - self._app_protocol.connection_made(self._app_transport) + ssl_object=sslobj) + if self._app_state == AppProtocolState.STATE_INIT: + self._app_state = AppProtocolState.STATE_CON_MADE + self._app_protocol.connection_made(self._get_app_transport()) self._wakeup_waiter() - self._session_established = True - # In case transport.write() was already called. Don't call - # immediately _process_write_backlog(), but schedule it: - # _on_handshake_complete() can be called indirectly from - # _process_write_backlog(), and _process_write_backlog() is not - # reentrant. - self._loop.call_soon(self._process_write_backlog) - - def _process_write_backlog(self): - # Try to make progress on the write backlog. - if self._transport is None or self._sslpipe is None: + self._do_read() + + # Shutdown flow + + def _start_shutdown(self): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN, + SSLProtocolState.UNWRAPPED + ) + ): + return + if self._app_transport is not None: + self._app_transport._closed = True + if self._state == SSLProtocolState.DO_HANDSHAKE: + self._abort() + else: + self._set_state(SSLProtocolState.FLUSHING) + self._shutdown_timeout_handle = self._loop.call_later( + self._ssl_shutdown_timeout, + lambda: self._check_shutdown_timeout() + ) + self._do_flush() + + def _check_shutdown_timeout(self): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN + ) + ): + self._transport._force_close( + exceptions.TimeoutError('SSL shutdown timed out')) + + def _do_flush(self): + self._do_read() + self._set_state(SSLProtocolState.SHUTDOWN) + self._do_shutdown() + + def _do_shutdown(self): + try: + if not self._eof_received: + self._sslobj.unwrap() + except SSLAgainErrors: + self._process_outgoing() + except ssl.SSLError as exc: + self._on_shutdown_complete(exc) + else: + self._process_outgoing() + self._call_eof_received() + self._on_shutdown_complete(None) + + def _on_shutdown_complete(self, shutdown_exc): + if self._shutdown_timeout_handle is not None: + self._shutdown_timeout_handle.cancel() + self._shutdown_timeout_handle = None + + if shutdown_exc: + self._fatal_error(shutdown_exc) + else: + self._loop.call_soon(self._transport.close) + + def _abort(self): + self._set_state(SSLProtocolState.UNWRAPPED) + if self._transport is not None: + self._transport.abort() + + # Outgoing flow + + def _write_appdata(self, list_of_data): + if ( + self._state in ( + SSLProtocolState.FLUSHING, + SSLProtocolState.SHUTDOWN, + SSLProtocolState.UNWRAPPED + ) + ): + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('SSL connection is closed') + self._conn_lost += 1 return + for data in list_of_data: + self._write_backlog.append(data) + self._write_buffer_size += len(data) + try: - for i in range(len(self._write_backlog)): - data, offset = self._write_backlog[0] - if data: - ssldata, offset = self._sslpipe.feed_appdata(data, offset) - elif offset: - ssldata = self._sslpipe.do_handshake( - self._on_handshake_complete) - offset = 1 + if self._state == SSLProtocolState.WRAPPED: + self._do_write() + + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + def _do_write(self): + try: + while self._write_backlog: + data = self._write_backlog[0] + count = self._sslobj.write(data) + data_len = len(data) + if count < data_len: + self._write_backlog[0] = data[count:] + self._write_buffer_size -= count else: - ssldata = self._sslpipe.shutdown(self._finalize) - offset = 1 - - for chunk in ssldata: - self._transport.write(chunk) - - if offset < len(data): - self._write_backlog[0] = (data, offset) - # A short write means that a write is blocked on a read - # We need to enable reading if it is paused! - assert self._sslpipe.need_ssldata - if self._transport._paused: - self._transport.resume_reading() - break + del self._write_backlog[0] + self._write_buffer_size -= data_len + except SSLAgainErrors: + pass + self._process_outgoing() + + def _process_outgoing(self): + if not self._ssl_writing_paused: + data = self._outgoing.read() + if len(data): + self._transport.write(data) + self._control_app_writing() + + # Incoming flow + + def _do_read(self): + if ( + self._state not in ( + SSLProtocolState.WRAPPED, + SSLProtocolState.FLUSHING, + ) + ): + return + try: + if not self._app_reading_paused: + if self._app_protocol_is_buffer: + self._do_read__buffered() + else: + self._do_read__copied() + if self._write_backlog: + self._do_write() + else: + self._process_outgoing() + self._control_ssl_reading() + except Exception as ex: + self._fatal_error(ex, 'Fatal error on SSL protocol') + + def _do_read__buffered(self): + offset = 0 + count = 1 + + buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) + wants = len(buf) - # An entire chunk from the backlog was processed. We can - # delete it and reduce the outstanding buffer size. - del self._write_backlog[0] - self._write_buffer_size -= len(data) - except (SystemExit, KeyboardInterrupt): + try: + count = self._sslobj.read(wants, buf) + + if count > 0: + offset = count + while offset < wants: + count = self._sslobj.read(wants - offset, buf[offset:]) + if count > 0: + offset += count + else: + break + else: + self._loop.call_soon(lambda: self._do_read()) + except SSLAgainErrors: + pass + if offset > 0: + self._app_protocol_buffer_updated(offset) + if not count: + # close_notify + self._call_eof_received() + self._start_shutdown() + + def _do_read__copied(self): + chunk = b'1' + zero = True + one = False + + try: + while True: + chunk = self._sslobj.read(self.max_size) + if not chunk: + break + if zero: + zero = False + one = True + first = chunk + elif one: + one = False + data = [first, chunk] + else: + data.append(chunk) + except SSLAgainErrors: + pass + if one: + self._app_protocol.data_received(first) + elif not zero: + self._app_protocol.data_received(b''.join(data)) + if not chunk: + # close_notify + self._call_eof_received() + self._start_shutdown() + + def _call_eof_received(self): + try: + if self._app_state == AppProtocolState.STATE_CON_MADE: + self._app_state = AppProtocolState.STATE_EOF + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + except (KeyboardInterrupt, SystemExit): raise - except BaseException as exc: - if self._in_handshake: - # Exceptions will be re-raised in _on_handshake_complete. - self._on_handshake_complete(exc) - else: - self._fatal_error(exc, 'Fatal error on SSL transport') + except BaseException as ex: + self._fatal_error(ex, 'Error calling eof_received()') + + # Flow control for writes from APP socket + + def _control_app_writing(self): + size = self._get_write_buffer_size() + if size >= self._outgoing_high_water and not self._app_writing_paused: + self._app_writing_paused = True + try: + self._app_protocol.pause_writing() + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + elif size <= self._outgoing_low_water and self._app_writing_paused: + self._app_writing_paused = False + try: + self._app_protocol.resume_writing() + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self._app_transport, + 'protocol': self, + }) + + def _get_write_buffer_size(self): + return self._outgoing.pending + self._write_buffer_size + + def _set_write_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE) + self._outgoing_high_water = high + self._outgoing_low_water = low + + # Flow control for reads to APP socket + + def _pause_reading(self): + self._app_reading_paused = True + + def _resume_reading(self): + if self._app_reading_paused: + self._app_reading_paused = False + + def resume(): + if self._state == SSLProtocolState.WRAPPED: + self._do_read() + elif self._state == SSLProtocolState.FLUSHING: + self._do_flush() + elif self._state == SSLProtocolState.SHUTDOWN: + self._do_shutdown() + self._loop.call_soon(resume) + + # Flow control for reads from SSL socket + + def _control_ssl_reading(self): + size = self._get_read_buffer_size() + if size >= self._incoming_high_water and not self._ssl_reading_paused: + self._ssl_reading_paused = True + self._transport.pause_reading() + elif size <= self._incoming_low_water and self._ssl_reading_paused: + self._ssl_reading_paused = False + self._transport.resume_reading() + + def _set_read_buffer_limits(self, high=None, low=None): + high, low = add_flowcontrol_defaults( + high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ) + self._incoming_high_water = high + self._incoming_low_water = low + + def _get_read_buffer_size(self): + return self._incoming.pending + + # Flow control for writes to SSL socket + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + assert not self._ssl_writing_paused + self._ssl_writing_paused = True + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + assert self._ssl_writing_paused + self._ssl_writing_paused = False + self._process_outgoing() def _fatal_error(self, exc, message='Fatal error on transport'): + if self._transport: + self._transport._force_close(exc) + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) - else: + elif not isinstance(exc, exceptions.CancelledError): self._loop.call_exception_handler({ 'message': message, 'exception': exc, 'transport': self._transport, 'protocol': self, }) - if self._transport: - self._transport._force_close(exc) - - def _finalize(self): - self._sslpipe = None - - if self._transport is not None: - self._transport.close() - - def _abort(self): - try: - if self._transport is not None: - self._transport.abort() - finally: - self._finalize() diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index a55b3a375fa22d..181e1885152fab 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -229,7 +229,8 @@ async def create_unix_connection( self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_shutdown_timeout=None): assert server_hostname is None or isinstance(server_hostname, str) if ssl: if server_hostname is None: @@ -241,6 +242,9 @@ async def create_unix_connection( if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') if path is not None: if sock is not None: @@ -267,13 +271,15 @@ async def create_unix_connection( transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout) return transport, protocol async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None, + ssl_shutdown_timeout=None, start_serving=True): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -282,6 +288,10 @@ async def create_unix_server( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_shutdown_timeout is not None and not ssl: + raise ValueError( + 'ssl_shutdown_timeout is only meaningful with ssl') + if path is not None: if sock is not None: raise ValueError( @@ -328,7 +338,8 @@ async def create_unix_server( sock.setblocking(False) server = base_events.Server(self, [sock], protocol_factory, - ssl, backlog, ssl_handshake_timeout) + ssl, backlog, ssl_handshake_timeout, + ssl_shutdown_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 5691d4250aca9e..be5ea1e3c738da 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1437,44 +1437,51 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = mock.ANY handshake_timeout = object() + shutdown_timeout = object() # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='python.org', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) def test_create_connection_no_ssl_server_hostname_errors(self): # When not using ssl, server_hostname must be None. @@ -1881,7 +1888,7 @@ def test_accept_connection_exception(self, m_log): constants.ACCEPT_RETRY_DELAY, # self.loop._start_serving mock.ANY, - MyProto, sock, None, None, mock.ANY, mock.ANY) + MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) def test_call_coroutine(self): with self.assertWarns(DeprecationWarning): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 1613c753c26ee5..349e4f2dca0634 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -70,44 +70,6 @@ def test_make_socket_transport(self): close_transport(transport) - @unittest.skipIf(ssl is None, 'No ssl module') - def test_make_ssl_transport(self): - m = mock.Mock() - self.loop._add_reader = mock.Mock() - self.loop._add_reader._is_coroutine = False - self.loop._add_writer = mock.Mock() - self.loop._remove_reader = mock.Mock() - self.loop._remove_writer = mock.Mock() - waiter = self.loop.create_future() - with test_utils.disable_logger(): - transport = self.loop._make_ssl_transport( - m, asyncio.Protocol(), m, waiter) - - with self.assertRaisesRegex(RuntimeError, - r'SSL transport.*not.*initialized'): - transport.is_reading() - - # execute the handshake while the logger is disabled - # to ignore SSL handshake failure - test_utils.run_briefly(self.loop) - - self.assertTrue(transport.is_reading()) - transport.pause_reading() - transport.pause_reading() - self.assertFalse(transport.is_reading()) - transport.resume_reading() - transport.resume_reading() - self.assertTrue(transport.is_reading()) - - # Sanity check - class_name = transport.__class__.__name__ - self.assertIn("ssl", class_name.lower()) - self.assertIn("transport", class_name.lower()) - - transport.close() - # execute pending callbacks to close the socket transport - test_utils.run_briefly(self.loop) - @mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py new file mode 100644 index 00000000000000..38235c63e01e93 --- /dev/null +++ b/Lib/test/test_asyncio/test_ssl.py @@ -0,0 +1,1718 @@ +import asyncio +import asyncio.sslproto +import contextlib +import gc +import logging +import os +import select +import socket +import ssl +import tempfile +import threading +import time +import weakref + +from test import support +from test.test_asyncio import utils as test_utils + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class MyBaseProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = asyncio.Future(loop=loop) + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class TestSSL(test_utils.TestCase): + + PAYLOAD_SIZE = 1024 * 100 + TIMEOUT = 60 + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + self.addCleanup(self.loop.close) + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=5, + backlog=1, + max_clients=10): + + if addr is None: + if family == getattr(socket, "AF_UNIX", None): + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + try: + sock.bind(addr) + sock.listen(backlog) + except OSError as ex: + sock.close() + raise ex + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=10): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) + + def _create_server_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _create_client_ssl_context(self, *, disable_verify=True): + sslcontext = ssl.create_default_context() + sslcontext.check_hostname = False + if disable_verify: + sslcontext.verify_mode = ssl.CERT_NONE + return sslcontext + + @contextlib.contextmanager + def _silence_eof_received_warning(self): + # TODO This warning has to be fixed in asyncio. + logger = logging.getLogger('asyncio') + filter = logging.Filter('has no effect when using ssl') + logger.addFilter(filter) + try: + yield + finally: + logger.removeFilter(filter) + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + def new_loop(self): + return asyncio.new_event_loop() + + def new_policy(self): + return asyncio.DefaultEventLoopPolicy() + + async def wait_closed(self, obj): + if not isinstance(obj, asyncio.StreamWriter): + return + try: + await obj.wait_closed() + except (BrokenPipeError, ConnectionError): + pass + + def test_create_server_ssl_1(self): + CNT = 0 # number of clients that were successful + TOTAL_CNT = 25 # total number of clients that test will create + TIMEOUT = 10.0 # timeout for this test + + A_DATA = b'A' * 1024 * 1024 + B_DATA = b'B' * 1024 * 1024 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + + data = await reader.readexactly(len(B_DATA)) + self.assertEqual(data, B_DATA) + writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) + + await writer.drain() + writer.close() + + CNT += 1 + + async def test_client(addr): + fut = asyncio.Future() + + def prog(sock): + try: + sock.starttls(client_sslctx) + sock.connect(addr) + sock.send(A_DATA) + + data = sock.recv_all(2) + self.assertEqual(data, b'OK') + + sock.send(B_DATA) + data = sock.recv_all(4) + self.assertEqual(data, b'SPAM') + + sock.close() + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + extras = dict(ssl_handshake_timeout=10.0) + + srv = await asyncio.start_server( + handle_client, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) + + finally: + self.loop.call_soon(srv.close) + await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + def test_create_connection_ssl_1(self): + self.loop.set_exception_handler(None) + + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * 1024 + B_DATA = b'B' * 1024 * 1024 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + sock.starttls( + sslctx, + server_side=True) + + data = sock.recv_all(len(A_DATA)) + self.assertEqual(data, A_DATA) + sock.send(b'OK') + + data = sock.recv_all(len(B_DATA)) + self.assertEqual(data, B_DATA) + sock.send(b'SPAM') + + sock.close() + + async def client(addr): + extras = {} + extras = dict(ssl_handshake_timeout=10.0) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + + async def client_sock(addr): + sock = socket.socket() + sock.connect(addr) + reader, writer = await asyncio.open_connection( + sock=sock, + ssl=client_sslctx, + server_hostname='') + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + sock.close() + + def run(coro): + nonlocal CNT + CNT = 0 + + async def _gather(*tasks): + # trampoline + return await asyncio.gather(*tasks) + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete(_gather(*tasks)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + with self._silence_eof_received_warning(): + run(client_sock) + + def test_create_connection_ssl_slow_handshake(self): + client_sslctx = self._create_client_ssl_context() + + # silence error logger + self.loop.set_exception_handler(lambda *args: None) + + def server(sock): + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=1.0) + writer.close() + await self.wait_closed(writer) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaisesRegex( + ConnectionAbortedError, + r'SSL handshake.*is taking longer'): + + self.loop.run_until_complete(client(srv.addr)) + + def test_create_connection_ssl_failed_certificate(self): + # silence error logger + self.loop.set_exception_handler(lambda *args: None) + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context(disable_verify=False) + + def server(sock): + try: + sock.starttls( + sslctx, + server_side=True) + sock.connect() + except (ssl.SSLError, OSError): + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=1.0) + writer.close() + await self.wait_closed(writer) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ssl.SSLCertVerificationError): + self.loop.run_until_complete(client(srv.addr)) + + def test_ssl_handshake_timeout(self): + # bpo-29970: Check that a connection is aborted if handshake is not + # completed in timeout period, instead of remaining open indefinitely + client_sslctx = test_utils.simple_client_sslcontext() + + # silence error logger + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server_side_aborted = False + + def server(sock): + nonlocal server_side_aborted + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + server_side_aborted = True + finally: + sock.close() + + async def client(addr): + await asyncio.wait_for( + self.loop.create_connection( + asyncio.Protocol, + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=10.0), + 0.5) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(client(srv.addr)) + + self.assertTrue(server_side_aborted) + + # Python issue #23197: cancelling a handshake must not raise an + # exception or log an error, even if the handshake failed + self.assertEqual(messages, []) + + def test_ssl_handshake_connection_lost(self): + # #246: make sure that no connection_lost() is called before + # connection_made() is called first + + client_sslctx = test_utils.simple_client_sslcontext() + + # silence error logger + self.loop.set_exception_handler(lambda loop, ctx: None) + + connection_made_called = False + connection_lost_called = False + + def server(sock): + sock.recv(1024) + # break the connection during handshake + sock.close() + + class ClientProto(asyncio.Protocol): + def connection_made(self, transport): + nonlocal connection_made_called + connection_made_called = True + + def connection_lost(self, exc): + nonlocal connection_lost_called + connection_lost_called = True + + async def client(addr): + await self.loop.create_connection( + ClientProto, + *addr, + ssl=client_sslctx, + server_hostname=''), + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ConnectionResetError): + self.loop.run_until_complete(client(srv.addr)) + + if connection_lost_called: + if connection_made_called: + self.fail("unexpected call to connection_lost()") + else: + self.fail("unexpected call to connection_lost() without" + "calling connection_made()") + elif connection_made_called: + self.fail("unexpected call to connection_made()") + + def test_ssl_connect_accepted_socket(self): + proto = ssl.PROTOCOL_TLS_SERVER + server_context = ssl.SSLContext(proto) + server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) + if hasattr(server_context, 'check_hostname'): + server_context.check_hostname = False + server_context.verify_mode = ssl.CERT_NONE + + client_context = ssl.SSLContext(proto) + if hasattr(server_context, 'check_hostname'): + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + + def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): + loop = self.loop + + class MyProto(MyBaseProto): + + def connection_lost(self, exc): + super().connection_lost(exc) + loop.call_soon(loop.stop) + + def data_received(self, data): + super().data_received(data) + self.transport.write(expected_response) + + lsock = socket.socket(socket.AF_INET) + lsock.bind(('127.0.0.1', 0)) + lsock.listen(1) + addr = lsock.getsockname() + + message = b'test data' + response = None + expected_response = b'roger' + + def client(): + nonlocal response + try: + csock = socket.socket(socket.AF_INET) + if client_ssl is not None: + csock = client_ssl.wrap_socket(csock) + csock.connect(addr) + csock.sendall(message) + response = csock.recv(99) + csock.close() + except Exception as exc: + print( + "Failure in client thread in test_connect_accepted_socket", + exc) + + thread = threading.Thread(target=client, daemon=True) + thread.start() + + conn, _ = lsock.accept() + proto = MyProto(loop=loop) + proto.loop = loop + + extras = {} + if server_ssl: + extras = dict(ssl_handshake_timeout=10.0) + + f = loop.create_task( + loop.connect_accepted_socket( + (lambda: proto), conn, ssl=server_ssl, + **extras)) + loop.run_forever() + conn.close() + lsock.close() + + thread.join(1) + self.assertFalse(thread.is_alive()) + self.assertEqual(proto.state, 'CLOSED') + self.assertEqual(proto.nbytes, len(message)) + self.assertEqual(response, expected_response) + tr, _ = f.result() + + if server_ssl: + self.assertIn('SSL', tr.__class__.__name__) + + tr.close() + # let it close + self.loop.run_until_complete(asyncio.sleep(0.1)) + + def test_start_tls_client_corrupted_ssl(self): + self.loop.set_exception_handler(lambda loop, ctx: None) + + sslctx = test_utils.simple_server_sslcontext() + client_sslctx = test_utils.simple_client_sslcontext() + + def server(sock): + orig_sock = sock.dup() + try: + sock.starttls( + sslctx, + server_side=True) + sock.sendall(b'A\n') + sock.recv_all(1) + orig_sock.send(b'please corrupt the SSL connection') + except ssl.SSLError: + pass + finally: + sock.close() + orig_sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + + self.assertEqual(await reader.readline(), b'A\n') + writer.write(b'B') + with self.assertRaises(ssl.SSLError): + await reader.readline() + writer.close() + try: + await self.wait_closed(writer) + except ssl.SSLError: + pass + return 'OK' + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + res = self.loop.run_until_complete(client(srv.addr)) + + self.assertEqual(res, 'OK') + + def test_start_tls_client_reg_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), timeout=10)) + + def test_create_connection_memory_leak(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_context = self._create_client_ssl_context() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + # XXX: We assume user stores the transport in protocol + proto.tr = tr + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr, + ssl=client_context) + + self.assertEqual(await on_data, b'O') + tr.write(HELLO_MSG) + await on_eof + + tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), timeout=10)) + + # No garbage is left for SSL client from loop.create_connection, even + # if user stores the SSLTransport in corresponding protocol instance + client_context = weakref.ref(client_context) + self.assertIsNone(client_context()) + + def test_start_tls_client_buf_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + client_con_made_calls = 0 + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.sendall(b'2') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProtoFirst(asyncio.BufferedProtocol): + def __init__(self, on_data): + self.on_data = on_data + self.buf = bytearray(1) + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def get_buffer(self, sizehint): + return self.buf + + def buffer_updated(self, nsize): + assert nsize == 1 + self.on_data.set_result(bytes(self.buf[:nsize])) + + def eof_received(self): + pass + + class ClientProtoSecond(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data1 = self.loop.create_future() + on_data2 = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProtoFirst(on_data1), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data1, b'O') + new_tr.write(HELLO_MSG) + + new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) + self.assertEqual(await on_data2, b'2') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + # connection_made() should be called only once -- when + # we establish connection for the first time. Start TLS + # doesn't call connection_made() on application protocols. + self.assertEqual(client_con_made_calls, 1) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=self.TIMEOUT)) + + def test_start_tls_slow_client_cancel(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + client_context = test_utils.simple_client_sslcontext() + server_waits_on_handshake = self.loop.create_future() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + try: + self.loop.call_soon_threadsafe( + server_waits_on_handshake.set_result, None) + data = sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + + await server_waits_on_handshake + + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + self.loop.start_tls(tr, proto, client_context), + 0.5) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), timeout=10)) + + def test_start_tls_server_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def client(sock, addr): + sock.settimeout(self.TIMEOUT) + + sock.connect(addr) + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(client_context) + sock.sendall(HELLO_MSG) + + sock.unwrap() + sock.close() + + class ServerProto(asyncio.Protocol): + def __init__(self, on_con, on_eof, on_con_lost): + self.on_con = on_con + self.on_eof = on_eof + self.on_con_lost = on_con_lost + self.data = b'' + + def connection_made(self, tr): + self.on_con.set_result(tr) + + def data_received(self, data): + self.data += data + + def eof_received(self): + self.on_eof.set_result(1) + + def connection_lost(self, exc): + if exc is None: + self.on_con_lost.set_result(None) + else: + self.on_con_lost.set_exception(exc) + + async def main(proto, on_con, on_eof, on_con_lost): + tr = await on_con + tr.write(HELLO_MSG) + + self.assertEqual(proto.data, b'') + + new_tr = await self.loop.start_tls( + tr, proto, server_context, + server_side=True, + ssl_handshake_timeout=self.TIMEOUT) + + await on_eof + await on_con_lost + self.assertEqual(proto.data, HELLO_MSG) + new_tr.close() + + async def run_main(): + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + on_con_lost = self.loop.create_future() + proto = ServerProto(on_con, on_eof, on_con_lost) + + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr), + timeout=self.TIMEOUT): + await asyncio.wait_for( + main(proto, on_con, on_eof, on_con_lost), + timeout=self.TIMEOUT) + + server.close() + await server.wait_closed() + + self.loop.run_until_complete(run_main()) + + def test_create_server_ssl_over_ssl(self): + CNT = 0 # number of clients that were successful + TOTAL_CNT = 25 # total number of clients that test will create + TIMEOUT = 10.0 # timeout for this test + + A_DATA = b'A' * 1024 * 1024 + B_DATA = b'B' * 1024 * 1024 + + sslctx_1 = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx_1 = self._create_client_ssl_context() + sslctx_2 = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx_2 = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + + data = await reader.readexactly(len(B_DATA)) + self.assertEqual(data, B_DATA) + writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) + + await writer.drain() + writer.close() + + CNT += 1 + + class ServerProtocol(asyncio.StreamReaderProtocol): + def connection_made(self, transport): + super_ = super() + transport.pause_reading() + fut = self._loop.create_task(self._loop.start_tls( + transport, self, sslctx_2, server_side=True)) + + def cb(_): + try: + tr = fut.result() + except Exception as ex: + super_.connection_lost(ex) + else: + super_.connection_made(tr) + fut.add_done_callback(cb) + + def server_protocol_factory(): + reader = asyncio.StreamReader() + protocol = ServerProtocol(reader, handle_client) + return protocol + + async def test_client(addr): + fut = asyncio.Future() + + def prog(sock): + try: + sock.connect(addr) + sock.starttls(client_sslctx_1) + + # because wrap_socket() doesn't work correctly on + # SSLSocket, we have to do the 2nd level SSL manually + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) + + def do(func, *args): + while True: + try: + rv = func(*args) + break + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(65536)) + if outgoing.pending: + sock.send(outgoing.read()) + return rv + + do(sslobj.do_handshake) + + do(sslobj.write, A_DATA) + data = do(sslobj.read, 2) + self.assertEqual(data, b'OK') + + do(sslobj.write, B_DATA) + data = b'' + while True: + chunk = do(sslobj.read, 4) + if not chunk: + break + data += chunk + self.assertEqual(data, b'SPAM') + + do(sslobj.unwrap) + sock.close() + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + sock.close() + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + + srv = await self.loop.create_server( + server_protocol_factory, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx_1, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) + + finally: + self.loop.call_soon(srv.close) + await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + def test_shutdown_cleanly(self): + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * 1024 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + sock.starttls( + sslctx, + server_side=True) + + data = sock.recv_all(len(A_DATA)) + self.assertEqual(data, A_DATA) + sock.send(b'OK') + + sock.unwrap() + + sock.close() + + async def client(addr): + extras = {} + extras = dict(ssl_handshake_timeout=10.0) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + self.assertEqual(await reader.read(), b'') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + + def run(coro): + nonlocal CNT + CNT = 0 + + async def _gather(*tasks): + return await asyncio.gather(*tasks) + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete( + _gather(*tasks)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + def test_flush_before_shutdown(self): + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + if hasattr(ssl, 'OP_NO_TLSv1_3'): + client_sslctx.options |= ssl.OP_NO_TLSv1_3 + + future = None + + def server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + time.sleep(0.5) # hopefully stuck the TCP buffer + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + sock.close() + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + async def client(addr): + nonlocal future + future = self.loop.create_future() + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + sslprotocol = writer.transport._ssl_protocol + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + sslprotocol.pause_writing() + for _ in range(SIZE): + writer.write(b'x' * CHUNK) + + writer.close() + sslprotocol.resume_writing() + + await self.wait_closed(writer) + try: + data = await reader.read() + self.assertEqual(data, b'') + except ConnectionResetError: + pass + await future + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_remote_shutdown_receives_trailing_data(self): + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + while True: + try: + data = sslobj.read(4) + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + else: + break + + self.assertEqual(data, b'ping') + sslobj.write(b'pong') + sock.send(outgoing.read()) + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) + + # should receive all data + data_len = 0 + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break + + self.assertEqual(data_len, CHUNK * SIZE) + + # verify that close_notify is received + sslobj.unwrap() + + sock.close() + + def eof_server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send EOF + sock.shutdown(socket.SHUT_WR) + + # should receive all data + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + # fill write backlog in a hacky way - renegotiation won't help + for _ in range(SIZE): + writer.transport._test__append_write_backlog(b'x' * CHUNK) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (BrokenPipeError, ConnectionResetError): + pass + + await future + + writer.close() + await self.wait_closed(writer) + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(eof_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_connect_timeout_warning(self): + s = socket.socket(socket.AF_INET) + s.bind(('127.0.0.1', 0)) + addr = s.getsockname() + + async def test(): + try: + await asyncio.wait_for( + self.loop.create_connection(asyncio.Protocol, + *addr, ssl=True), + 0.1) + except (ConnectionRefusedError, asyncio.TimeoutError): + pass + else: + self.fail('TimeoutError is not raised') + + with s: + try: + with self.assertWarns(ResourceWarning) as cm: + self.loop.run_until_complete(test()) + gc.collect() + gc.collect() + gc.collect() + except AssertionError as e: + self.assertEqual(str(e), 'ResourceWarning not triggered') + else: + self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) + + def test_handshake_timeout_handler_leak(self): + s = socket.socket(socket.AF_INET) + s.bind(('127.0.0.1', 0)) + s.listen(1) + addr = s.getsockname() + + async def test(ctx): + try: + await asyncio.wait_for( + self.loop.create_connection(asyncio.Protocol, *addr, + ssl=ctx), + 0.1) + except (ConnectionRefusedError, asyncio.TimeoutError): + pass + else: + self.fail('TimeoutError is not raised') + + with s: + ctx = ssl.create_default_context() + self.loop.run_until_complete(test(ctx)) + ctx = weakref.ref(ctx) + + # SSLProtocol should be DECREF to 0 + self.assertIsNone(ctx()) + + def test_shutdown_timeout_handler_leak(self): + loop = self.loop + + def server(sock): + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + sock = sslctx.wrap_socket(sock, server_side=True) + sock.recv(32) + sock.close() + + class Protocol(asyncio.Protocol): + def __init__(self): + self.fut = asyncio.Future(loop=loop) + + def connection_lost(self, exc): + self.fut.set_result(None) + + async def client(addr, ctx): + tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) + tr.close() + await pr.fut + + with self.tcp_server(server) as srv: + ctx = self._create_client_ssl_context() + loop.run_until_complete(client(srv.addr, ctx)) + ctx = weakref.ref(ctx) + + # asyncio has no shutdown timeout, but it ends up with a circular + # reference loop - not ideal (introduces gc glitches), but at least + # not leaking + gc.collect() + gc.collect() + gc.collect() + + # SSLProtocol should be DECREF to 0 + self.assertIsNone(ctx()) + + def test_shutdown_timeout_handler_not_set(self): + loop = self.loop + eof = asyncio.Event() + extra = None + + def server(sock): + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + sock = sslctx.wrap_socket(sock, server_side=True) + sock.send(b'hello') + assert sock.recv(1024) == b'world' + sock.send(b'extra bytes') + # sending EOF here + sock.shutdown(socket.SHUT_WR) + loop.call_soon_threadsafe(eof.set) + # make sure we have enough time to reproduce the issue + assert sock.recv(1024) == b'' + sock.close() + + class Protocol(asyncio.Protocol): + def __init__(self): + self.fut = asyncio.Future(loop=loop) + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + if data == b'hello': + self.transport.write(b'world') + # pause reading would make incoming data stay in the sslobj + self.transport.pause_reading() + else: + nonlocal extra + extra = data + + def connection_lost(self, exc): + if exc is None: + self.fut.set_result(None) + else: + self.fut.set_exception(exc) + + async def client(addr): + ctx = self._create_client_ssl_context() + tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) + await eof.wait() + tr.resume_reading() + await pr.fut + tr.close() + assert extra == b'extra bytes' + + with self.tcp_server(server) as srv: + loop.run_until_complete(client(srv.addr)) + + +############################################################################### +# Socket Testing Utilities +############################################################################### + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def starttls(self, ssl_context, *, + server_side=False, + server_hostname=None, + do_handshake_on_connect=True): + + assert isinstance(ssl_context, ssl.SSLContext) + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=do_handshake_on_connect) + + if server_side: + ssl_sock.do_handshake() + + self.__sock.close() + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(0) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + except BlockingIOError: + continue + except socket.timeout: + if not self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index e87863eb712373..79a81bd8c39b69 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -15,7 +15,6 @@ from asyncio import log from asyncio import protocols from asyncio import sslproto -from test import support from test.test_asyncio import utils as test_utils from test.test_asyncio import functional as func_tests @@ -44,16 +43,13 @@ def ssl_protocol(self, *, waiter=None, proto=None): def connection_made(self, ssl_proto, *, do_handshake=None): transport = mock.Mock() - sslpipe = mock.Mock() - sslpipe.shutdown.return_value = b'' - if do_handshake: - sslpipe.do_handshake.side_effect = do_handshake - else: - def mock_handshake(callback): - return [] - sslpipe.do_handshake.side_effect = mock_handshake - with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): - ssl_proto.connection_made(transport) + sslobj = mock.Mock() + # emulate reading decompressed data + sslobj.read.side_effect = ssl.SSLWantReadError + if do_handshake is not None: + sslobj.do_handshake = do_handshake + ssl_proto._sslobj = sslobj + ssl_proto.connection_made(transport) return transport def test_handshake_timeout_zero(self): @@ -75,7 +71,10 @@ def test_handshake_timeout_negative(self): def test_eof_received_waiter(self): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made(ssl_proto) + self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) ssl_proto.eof_received() test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionResetError) @@ -100,7 +99,10 @@ def test_connection_lost(self): # yield from waiter hang if lost_connection was called. waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made(ssl_proto) + self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) ssl_proto.connection_lost(ConnectionAbortedError) test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionAbortedError) @@ -110,7 +112,10 @@ def test_close_during_handshake(self): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - transport = self.connection_made(ssl_proto) + transport = self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) test_utils.run_briefly(self.loop) ssl_proto._app_transport.close() @@ -143,7 +148,7 @@ def test_data_received_after_closing(self): transp.close() # should not raise - self.assertIsNone(ssl_proto.data_received(b'data')) + self.assertIsNone(ssl_proto.buffer_updated(5)) def test_write_after_closing(self): ssl_proto = self.ssl_protocol() diff --git a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst new file mode 100644 index 00000000000000..e2b5a9e395ef1c --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst @@ -0,0 +1,2 @@ +Reimplement SSL/TLS support in asyncio, borrow the impelementation from +uvloop library.