diff --git a/Lib/test/support/threading_helper.py b/Lib/test/support/threading_helper.py index b9973c8bf5c914..1cd2a020ced0d6 100644 --- a/Lib/test/support/threading_helper.py +++ b/Lib/test/support/threading_helper.py @@ -3,11 +3,27 @@ import functools import sys import threading +from socket import socket import time import unittest +from concurrent.futures import ThreadPoolExecutor +from test.support.socket_helper import bind_port, HOST from test import support +_thread_pool = None + + +def _release(): + global _thread_pool + _thread_pool = None + + +def init(): + global _thread_pool + _thread_pool = ThreadPoolExecutor() + unittest.addModuleCleanup(_release) + #======================================================================= # Threading support to prevent reporting refleaks when running regrtest.py -R @@ -240,3 +256,89 @@ def requires_working_threading(*, module=False): raise unittest.SkipTest(msg) else: return unittest.skipUnless(can_start_thread, msg) + + +class Server: + """A context manager for a blocking server in a thread pool. + + The server is designed: + + - for testing purposes so it serves a fixed count of clients, one by one + - to be one-pass, short-lived, and terminated by in-protocol means so no + stopper flag is used + - to be used where asyncio has no application + + The server listens on an address returned from the ``with`` statement. + + For each client connected, the server calls a user-supplied function and + preserves whatever the function returns or throws to pass it to a client + later. + + When a client attempt to exit the context manager, it blocks until a server + stops processing all clients and exits. + """ + + def __init__(self, client_func, *args, client_count=1, **kwargs): + """Create and run the server. + + The method blocks until the server is ready to accept clients. + + After this constructor returns, the server: + + 1. Consequently waits for client_count clients + 1. For each client: + a. Calls client_func for each of them + b. Closes client connection when the function returns + c. Collects returned values into Server.result list + 5. Terminates a server + 6. Allows a context manager to exit + 7. Since ``with ... as`` section keeps its parameter alive, + Server.result field can be accessed outside of the section. + + If client_func raises an exception, the server is stopped, all pending + clients are discarded and the context manager raises an exception. + + Args: + client_func: a function called in a dedicated thread for each new + connected client. The function receives all argument passed to + the __init__ method excluding client_func and client_count. + args: positional arguments passed to client_func. + client_count: count of clients the server processes one by one + before stopping. + results: a reference to a list for collecting client_func + return values. Populated after the execution leaves a ``with`` + blocks associated with the Server context manager. + kwargs: keyword arguments passed to client_func. + """ + server_socket = socket() + self._port = bind_port(server_socket) + server_socket.listen() + self._result = _thread_pool.submit(self._thread_func, server_socket, + client_func, client_count, + args, kwargs) + + def _thread_func(self, server_socket, client_func, client_count, + args, kwargs): + with server_socket: + results = [] + for i in range(client_count): + client, peer_address = server_socket.accept() + with client: + r = client_func(client, peer_address, *args, **kwargs) + results.append(r) + return results + + def __enter__(self): + return HOST, self._port + + def __exit__(self, etype, evalue, traceback): + peer_willingly_closed = isinstance(etype, ConnectionError) + # We find our client disappeared when our socket read() fails. + peer_disappeared = etype is OSError + if peer_willingly_closed or peer_disappeared: + if self._result.exception() is not None: + generic = RuntimeError('server-side error') + raise generic from self._result.exception() + return False + self.result = self._result.result() + return False diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d4eb2d2e81fe0f..dc15c36d2b175b 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -318,6 +318,164 @@ def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True): return client_context, server_context, hostname +def _on_ssl_client(socket, peer_address, certificate=None, + certreqs=ssl.CERT_NONE, cacerts=None, + chatty=True, starttls_server=False, + alpn_protocols=None, ciphers=None, context=None): + # A mildly complicated server, because we want it to work both + # with and without the SSL wrapper around the socket connection, so + # that we can test the STARTTLS functionality. + + def log(message): + if support.verbose and chatty: + sys.stdout.write(f' server: {message}\n') + + if context is None: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.verify_mode = certreqs + if cacerts: + context.load_verify_locations(cacerts) + if certificate: + context.load_cert_chain(certificate) + if alpn_protocols: + context.set_alpn_protocols(alpn_protocols) + if ciphers: + context.set_ciphers(ciphers) + + # Returned via the future + selected_alpn_protocols = [] + shared_ciphers = [] + + # Functions swithed on wrapping/unwrapping + read = lambda: socket.recv(1024) + write = socket.send + # A caller of on_client will close the socket + close = lambda: None + + def wrap_conn(socket): + # Notes on how to treat exceptions thrown by wrap_socket: + # + # We treat ConnectionResetError as though it were an + # SSLError - OpenSSL on Ubuntu abruptly closes the + # connection when asked to use an unsupported protocol. + # + # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL + # tries to send session tickets after handshake. + # https://github.com/openssl/openssl/issues/6342 + # + # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL + # tries to send session tickets after handshake when using WinSock. + # + # OSError may occur with wrong protocols, e.g. both + # sides use PROTOCOL_TLS_SERVER. + + try: + sslconn = context.wrap_socket(socket, server_side=True) + nonlocal read, write, close + read = lambda: sslconn.recv(1024) + write = sslconn.write + close = sslconn.close + + nonlocal selected_alpn_protocols, shared_ciphers + selected_alpn_protocols = sslconn.selected_alpn_protocol() + shared_ciphers = sslconn.shared_ciphers() + + if context.verify_mode == ssl.CERT_REQUIRED: + cert = sslconn.getpeercert() + log(f'client cert is {pprint.pformat(cert)}') + cert_binary = sslconn.getpeercert(True) + if cert_binary is None: + log('client did not provide a cert') + else: + log(f'cert binary is {len(cert_binary)}b') + + log(f'connection cipher is now {sslconn.cipher()}') + return sslconn + + except (ssl.SSLError, OSError) as e: + # bpo-44229, bpo-43855, bpo-44237, and bpo-33450: + # Ignore spurious EPROTOTYPE returned by write() on macOS. + # See also http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ + if e.errno != errno.EPROTOTYPE and sys.platform != 'darwin': + raise + + log(f'new connection from {peer_address!r}') + sslconn = None if starttls_server else wrap_conn(socket) + + forced_exit = False + while not forced_exit and (msg := read()): + stripped = msg.strip() + + if stripped == b'over': + log('client closed connection') + forced_exit = True + close() + + elif starttls_server and stripped == b'STARTTLS': + log('read STARTTLS from client, sending OK') + write(b'OK\n') + sslconn = wrap_conn(socket) + + elif starttls_server and sslconn and stripped == b'ENDTLS': + log('read ENDTLS from client, sending OK') + write(b'OK\n') + socket = sslconn.unwrap() + sslconn = None + log('connection is now unencrypted') + + elif stripped == b'CB tls-unique': + log('read CB tls-unique from client, sending our CB data') + data = sslconn.get_channel_binding('tls-unique') + write(repr(data).encode('us-ascii') + b'\n') + + elif stripped == b'PHA': + log('initiating post handshake auth') + try: + sslconn.verify_client_post_handshake() + except ssl.SSLError as e: + write(repr(e).encode('us-ascii') + b'\n') + else: + write(b'OK\n') + + elif stripped == b'HASCERT': + if sslconn.getpeercert() is not None: + write(b'TRUE\n') + else: + write(b'FALSE\n') + + elif stripped == b'GETCERT': + cert = sslconn.getpeercert() + write(repr(cert).encode('us-ascii') + b'\n') + + elif stripped == b'VERIFIEDCHAIN': + certs = sslconn._sslobj.get_verified_chain() + write(len(certs).to_bytes(1, 'big') + b'\n') + + elif stripped == b'UNVERIFIEDCHAIN': + certs = sslconn._sslobj.get_unverified_chain() + write(len(certs).to_bytes(1, 'big') + b'\n') + + else: + ctype = 'encrypted' if sslconn else 'unencrypted' + in_str = msg.decode() + out_str = in_str.lower() + log(f'read {in_str} ({ctype}), sending back {out_str} ({ctype})') + write(out_str.encode()) + + try: + socket = sslconn.unwrap() + except OSError: + # Many tests shut the TCP connection down without an SSL shutdown. + # This causes unwrap() to raise OSError with errno=0. + pass + + close() + return selected_alpn_protocols, shared_ciphers + + +Server = functools.partial(threading_helper.Server, _on_ssl_client) + + class BasicSocketTests(unittest.TestCase): def test_constants(self): @@ -889,13 +1047,15 @@ def test_connect_ex_error(self): def test_read_write_zero(self): # empty reads and writes now work, bpo-42854, bpo-31711 client_context, server_context, hostname = testing_context() - server = ThreadedEchoServer(context=server_context) - with server: + with Server(context=server_context) as address: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) self.assertEqual(s.recv(0), b"") self.assertEqual(s.send(b""), 0) + # OpenSSL postpones the handshake until some data are sent so + # force it by queuing explicit shutdown. + s.unwrap().close() class ContextTests(unittest.TestCase): @@ -2699,6 +2859,20 @@ def try_protocol_combo(server_protocol, client_protocol, expect_success, class ThreadedTests(unittest.TestCase): + def wait_connection(self, socket): + """Force a socket to immediately initiate and process a TLS handshake. + + OpenSSL delays TLS 1.3 session ticket exchange until a socket user + attempts to send some data. As a result, we need to write some + non-empty string to force the handshake and avoid server-side + ConnectionAbortedError ("An established connection was aborted by the + software in your host machine") when some test has nothing to send and + closes a half-open TLS connection. + """ + echo_message = b'hi' + socket.write(echo_message) + socket.read(len(echo_message)) + def test_echo(self): """Basic test of an SSL client connecting to a server""" if support.verbose: @@ -2819,38 +2993,52 @@ def test_crl_check(self): cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") - def test_check_hostname(self): + def test_check_hostname_correct(self): if support.verbose: sys.stdout.write("\n") client_context, server_context, hostname = testing_context() - # correct hostname should verify - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: + with Server(context=server_context) as address: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") - # incorrect hostname should raise an exception - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with client_context.wrap_socket(socket.socket(), - server_hostname="invalid") as s: - with self.assertRaisesRegex( - ssl.CertificateError, - "Hostname mismatch, certificate is not valid for 'invalid'."): - s.connect((HOST, server.port)) + def test_check_hostname_incorrect(self): + if support.verbose: + sys.stdout.write("\n") - # missing server_hostname arg should cause an exception, too - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with socket.socket() as s: - with self.assertRaisesRegex(ValueError, - "check_hostname requires server_hostname"): - client_context.wrap_socket(s) + if sys.platform == 'darwin': + server_exc_type = OSError + server_exc_msg = '[Errno 9]' + else: + server_exc_type = ssl.SSLError + server_exc_msg = 'SSLV3_ALERT_BAD_CERTIFICATE' + + with self.assertRaisesRegex(server_exc_type, server_exc_msg): + client_context, server_context, hostname = testing_context() + + with Server(context=server_context) as address: + with client_context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex( + ssl.CertificateError, + "Hostname mismatch, certificate is not valid for 'invalid'."): + s.connect(address) + + def test_check_hostname_missing(self): + if support.verbose: + sys.stdout.write("\n") + + client_context, server_context, hostname = testing_context() + + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + client_context.wrap_socket(s) @unittest.skipUnless( ssl.HAS_NEVER_CHECK_COMMON_NAME, "test requires hostname_checks_common_name" @@ -2887,11 +3075,11 @@ def test_ecc_cert(self): server_context.load_cert_chain(SIGNED_CERTFILE_ECC) # correct hostname should verify - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: + with Server(context=server_context) as address: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") cipher = s.cipher()[0].split('-') @@ -2913,11 +3101,11 @@ def test_dual_rsa_ecc(self): server_context.load_cert_chain(SIGNED_CERTFILE) # correct hostname should verify - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: + with Server(context=server_context) as address: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") cipher = s.cipher()[0].split('-') @@ -3740,21 +3928,17 @@ def test_default_ecdh_curve(self): @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") def test_tls_unique_channel_binding(self): - """Test tls-unique channel binding.""" if support.verbose: sys.stdout.write("\n") client_context, server_context, hostname = testing_context() - server = ThreadedEchoServer(context=server_context, - chatty=True, - connectionchatty=False) - - with server: + with Server(context=server_context, client_count=2) as address: with client_context.wrap_socket( socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) # get the data cb_data = s.get_channel_binding("tls-unique") if support.verbose: @@ -3778,7 +3962,7 @@ def test_tls_unique_channel_binding(self): with client_context.wrap_socket( socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) new_cb_data = s.get_channel_binding("tls-unique") if support.verbose: sys.stdout.write( @@ -4147,14 +4331,14 @@ def test_session_handling(self): client_context.maximum_version = ssl.TLSVersion.TLSv1_2 client_context2.maximum_version = ssl.TLSVersion.TLSv1_2 - server = ThreadedEchoServer(context=server_context, chatty=False) - with server: + with Server(context=server_context, chatty=False, client_count=3) as address: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: # session is None before handshake self.assertEqual(s.session, None) self.assertEqual(s.session_reused, None) - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) session = s.session self.assertTrue(session) with self.assertRaises(TypeError) as e: @@ -4163,7 +4347,8 @@ def test_session_handling(self): with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) # cannot set session after handshake with self.assertRaises(ValueError) as e: s.session = session @@ -4175,7 +4360,8 @@ def test_session_handling(self): # can set session before handshake and before the # connection was established s.session = session - s.connect((HOST, server.port)) + s.connect(address) + self.wait_connection(s) self.assertEqual(s.session.id, session.id) self.assertEqual(s.session, session) self.assertEqual(s.session_reused, True) @@ -4185,7 +4371,7 @@ def test_session_handling(self): # cannot re-use session with a different SSLContext with self.assertRaises(ValueError) as e: s.session = session - s.connect((HOST, server.port)) + s.connect(address) self.assertEqual(str(e.exception), 'Session refers to a different SSLContext.') @@ -4803,8 +4989,8 @@ def setUpModule(): if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) - thread_info = threading_helper.threading_setup() - unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info) + threading_helper.init() + socket.setdefaulttimeout(support.LOOPBACK_TIMEOUT) if __name__ == "__main__": diff --git a/Misc/NEWS.d/next/Tests/2022-07-06-18-46-00.gh-issue-94609.eeM1cu.rst b/Misc/NEWS.d/next/Tests/2022-07-06-18-46-00.gh-issue-94609.eeM1cu.rst new file mode 100644 index 00000000000000..3e4192c058b158 --- /dev/null +++ b/Misc/NEWS.d/next/Tests/2022-07-06-18-46-00.gh-issue-94609.eeM1cu.rst @@ -0,0 +1,3 @@ +Add a unified :class:`test.support.threading_helper.Server` context manager +to create, wait, destroy and communicate with a blocking TCP server both +in-band and out-of-band. Patch by Oleg Iarygin.