Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.11] gh-122133: Rework pure Python socketpair tests to avoid use of importlib.reload. (GH-122493) #122506

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 58 additions & 63 deletions Lib/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,65 @@ def fromshare(info):
return socket(0, 0, 0, info)
__all__.append("fromshare")

if hasattr(_socket, "socketpair"):
# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
# This is used if _socket doesn't natively provide socketpair. It's
# always defined so that it can be patched in for testing purposes.
def _fallback_socketpair(family=AF_INET, type=SOCK_STREAM, proto=0):
if family == AF_INET:
host = _LOCALHOST
elif family == AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError("Only AF_INET and AF_INET6 socket address families "
"are supported")
if type != SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen()
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket(family, type, proto)
try:
csock.setblocking(False)
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
csock.setblocking(True)
ssock, _ = lsock.accept()
except:
csock.close()
raise
finally:
lsock.close()

def socketpair(family=None, type=SOCK_STREAM, proto=0):
"""socketpair([family[, type[, proto]]]) -> (socket object, socket object)
# Authenticating avoids using a connection from something else
# able to connect to {host}:{port} instead of us.
# We expect only AF_INET and AF_INET6 families.
try:
if (
ssock.getsockname() != csock.getpeername()
or csock.getsockname() != ssock.getpeername()
):
raise ConnectionError("Unexpected peer connection")
except:
# getsockname() and getpeername() can fail
# if either socket isn't connected.
ssock.close()
csock.close()
raise

Create a pair of socket objects from the sockets returned by the platform
socketpair() function.
The arguments are the same as for socket() except the default family is
AF_UNIX if defined on the platform; otherwise, the default is AF_INET.
"""
return (ssock, csock)

if hasattr(_socket, "socketpair"):
def socketpair(family=None, type=SOCK_STREAM, proto=0):
if family is None:
try:
family = AF_UNIX
Expand All @@ -611,61 +660,7 @@ def socketpair(family=None, type=SOCK_STREAM, proto=0):
return a, b

else:

# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
def socketpair(family=AF_INET, type=SOCK_STREAM, proto=0):
if family == AF_INET:
host = _LOCALHOST
elif family == AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError("Only AF_INET and AF_INET6 socket address families "
"are supported")
if type != SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen()
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket(family, type, proto)
try:
csock.setblocking(False)
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
csock.setblocking(True)
ssock, _ = lsock.accept()
except:
csock.close()
raise
finally:
lsock.close()

# Authenticating avoids using a connection from something else
# able to connect to {host}:{port} instead of us.
# We expect only AF_INET and AF_INET6 families.
try:
if (
ssock.getsockname() != csock.getpeername()
or csock.getsockname() != ssock.getpeername()
):
raise ConnectionError("Unexpected peer connection")
except:
# getsockname() and getpeername() can fail
# if either socket isn't connected.
ssock.close()
csock.close()
raise

return (ssock, csock)
socketpair = _fallback_socketpair
__all__.append("socketpair")

socketpair.__doc__ = """socketpair([family[, type[, proto]]]) -> (socket object, socket object)
Expand Down
20 changes: 6 additions & 14 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4676,7 +4676,6 @@ def _testSend(self):


class PurePythonSocketPairTest(SocketPairTest):

# Explicitly use socketpair AF_INET or AF_INET6 to ensure that is the
# code path we're using regardless platform is the pure python one where
# `_socket.socketpair` does not exist. (AF_INET does not work with
Expand All @@ -4691,28 +4690,21 @@ def socketpair(self):
# Local imports in this class make for easy security fix backporting.

def setUp(self):
import _socket
self._orig_sp = getattr(_socket, 'socketpair', None)
if self._orig_sp is not None:
if hasattr(_socket, "socketpair"):
self._orig_sp = socket.socketpair
# This forces the version using the non-OS provided socketpair
# emulation via an AF_INET socket in Lib/socket.py.
del _socket.socketpair
import importlib
global socket
socket = importlib.reload(socket)
socket.socketpair = socket._fallback_socketpair
else:
pass # This platform already uses the non-OS provided version.
# This platform already uses the non-OS provided version.
self._orig_sp = None
super().setUp()

def tearDown(self):
super().tearDown()
import _socket
if self._orig_sp is not None:
# Restore the default socket.socketpair definition.
_socket.socketpair = self._orig_sp
import importlib
global socket
socket = importlib.reload(socket)
socket.socketpair = self._orig_sp

def test_recv(self):
msg = self.serv.recv(1024)
Expand Down
Loading