From c9e5b67122f92eae2790d75f40225527341d5522 Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Wed, 31 Jul 2024 16:24:15 +0800 Subject: [PATCH] gh-122133: Rework pure Python socketpair tests to avoid use of importlib.reload. (GH-122493) (cherry picked from commit f071f01b7b7e19d7d6b3a4b0ec62f820ecb14660) Co-authored-by: Russell Keith-Magee Co-authored-by: Gregory P. Smith --- Lib/socket.py | 121 +++++++++++++++++++--------------------- Lib/test/test_socket.py | 20 ++----- 2 files changed, 64 insertions(+), 77 deletions(-) diff --git a/Lib/socket.py b/Lib/socket.py index 591d4739a64a91..f386241abfb79b 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -590,16 +590,65 @@ def fromshare(info): return socket(0, 0, 0, info) __all__.append("fromshare") -if hasattr(_socket, "socketpair"): +# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +# This is used if _socket doesn't natively provide socketpair. It's +# always defined so that it can be patched in for testing purposes. +def _fallback_socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): + if family == AF_INET: + host = _LOCALHOST + elif family == AF_INET6: + host = _LOCALHOST_V6 + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with + # setblocking(False) that prevents us from having to create a thread. + lsock = socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen() + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() - def socketpair(family=None, type=SOCK_STREAM, proto=0): - """socketpair([family[, type[, proto]]]) -> (socket object, socket object) + # Authenticating avoids using a connection from something else + # able to connect to {host}:{port} instead of us. + # We expect only AF_INET and AF_INET6 families. + try: + if ( + ssock.getsockname() != csock.getpeername() + or csock.getsockname() != ssock.getpeername() + ): + raise ConnectionError("Unexpected peer connection") + except: + # getsockname() and getpeername() can fail + # if either socket isn't connected. + ssock.close() + csock.close() + raise - Create a pair of socket objects from the sockets returned by the platform - socketpair() function. - The arguments are the same as for socket() except the default family is - AF_UNIX if defined on the platform; otherwise, the default is AF_INET. - """ + return (ssock, csock) + +if hasattr(_socket, "socketpair"): + def socketpair(family=None, type=SOCK_STREAM, proto=0): if family is None: try: family = AF_UNIX @@ -611,61 +660,7 @@ def socketpair(family=None, type=SOCK_STREAM, proto=0): return a, b else: - - # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - def socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): - if family == AF_INET: - host = _LOCALHOST - elif family == AF_INET6: - host = _LOCALHOST_V6 - else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") - if type != SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with - # setblocking(False) that prevents us from having to create a thread. - lsock = socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen() - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except: - csock.close() - raise - finally: - lsock.close() - - # Authenticating avoids using a connection from something else - # able to connect to {host}:{port} instead of us. - # We expect only AF_INET and AF_INET6 families. - try: - if ( - ssock.getsockname() != csock.getpeername() - or csock.getsockname() != ssock.getpeername() - ): - raise ConnectionError("Unexpected peer connection") - except: - # getsockname() and getpeername() can fail - # if either socket isn't connected. - ssock.close() - csock.close() - raise - - return (ssock, csock) + socketpair = _fallback_socketpair __all__.append("socketpair") socketpair.__doc__ = """socketpair([family[, type[, proto]]]) -> (socket object, socket object) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index a60eb436c7b5e0..cc803d8753b513 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -4676,7 +4676,6 @@ def _testSend(self): class PurePythonSocketPairTest(SocketPairTest): - # Explicitly use socketpair AF_INET or AF_INET6 to ensure that is the # code path we're using regardless platform is the pure python one where # `_socket.socketpair` does not exist. (AF_INET does not work with @@ -4691,28 +4690,21 @@ def socketpair(self): # Local imports in this class make for easy security fix backporting. def setUp(self): - import _socket - self._orig_sp = getattr(_socket, 'socketpair', None) - if self._orig_sp is not None: + if hasattr(_socket, "socketpair"): + self._orig_sp = socket.socketpair # This forces the version using the non-OS provided socketpair # emulation via an AF_INET socket in Lib/socket.py. - del _socket.socketpair - import importlib - global socket - socket = importlib.reload(socket) + socket.socketpair = socket._fallback_socketpair else: - pass # This platform already uses the non-OS provided version. + # This platform already uses the non-OS provided version. + self._orig_sp = None super().setUp() def tearDown(self): super().tearDown() - import _socket if self._orig_sp is not None: # Restore the default socket.socketpair definition. - _socket.socketpair = self._orig_sp - import importlib - global socket - socket = importlib.reload(socket) + socket.socketpair = self._orig_sp def test_recv(self): msg = self.serv.recv(1024)