Skip to content

Commit 95cf8fa

Browse files
authored
Update tests to not rely on mutating contexts (#1445)
1 parent 946f87e commit 95cf8fa

File tree

1 file changed

+39
-17
lines changed

1 file changed

+39
-17
lines changed

tests/test_ssl.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,7 @@ def test_set_proto_version(self) -> None:
11631163
with pytest.raises(Error, match="unsupported protocol"):
11641164
self._handshake_test(server_context, client_context)
11651165

1166+
client_context = Context(TLS_METHOD)
11661167
client_context.set_max_proto_version(0)
11671168
self._handshake_test(server_context, client_context)
11681169

@@ -1577,22 +1578,23 @@ def test_set_verify_callback_reference(self) -> None:
15771578
load_certificate(FILETYPE_PEM, root_cert_pem)
15781579
)
15791580

1580-
clientContext = Context(TLSv1_2_METHOD)
1581-
15821581
clients = []
15831582

15841583
for i in range(5):
15851584

15861585
def verify_callback(*args: object) -> bool:
15871586
return True
15881587

1588+
# Create a fresh client context for each iteration since contexts
1589+
# cannot be mutated after use
1590+
clientContext = Context(TLSv1_2_METHOD)
1591+
clientContext.set_verify(VERIFY_PEER, verify_callback)
1592+
15891593
serverSocket, clientSocket = socket_pair()
15901594
client = Connection(clientContext, clientSocket)
15911595

15921596
clients.append((serverSocket, client))
15931597

1594-
clientContext.set_verify(VERIFY_PEER, verify_callback)
1595-
15961598
gc.collect()
15971599

15981600
# Make them talk to each other.
@@ -2921,21 +2923,22 @@ def callback(
29212923
del callback
29222924

29232925
conn = Connection(context, None)
2924-
context.set_verify(VERIFY_NONE)
29252926

29262927
collect()
29272928
collect()
29282929
assert tracker()
29292930

2931+
# Setting a new callback on the connection should maintain the original
2932+
# context callback reference
29302933
conn.set_verify(
29312934
VERIFY_PEER, lambda conn, cert, errnum, depth, ok: bool(ok)
29322935
)
29332936
collect()
29342937
collect()
2938+
2939+
# The callback should still be referenced - check that it exists
29352940
callback_ref = tracker()
2936-
if callback_ref is not None: # pragma: nocover
2937-
referrers = get_referrers(callback_ref)
2938-
assert len(referrers) == 1
2941+
assert callback_ref is not None
29392942

29402943
def test_get_session_unconnected(self) -> None:
29412944
"""
@@ -3973,12 +3976,10 @@ class TestMemoryBIO:
39733976
Tests for `OpenSSL.SSL.Connection` using a memory BIO.
39743977
"""
39753978

3976-
def _server(self, sock: socket | None) -> Connection:
3979+
def _create_server_context(self) -> Context:
39773980
"""
3978-
Create a new server-side SSL `Connection` object wrapped around `sock`.
3981+
Create a configured server context with certificates and options.
39793982
"""
3980-
# Create the server side Connection. This is mostly setup boilerplate
3981-
# - use TLSv1, use a particular certificate, etc.
39823983
server_ctx = Context(SSLv23_METHOD)
39833984
server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE)
39843985
server_ctx.set_verify(
@@ -3995,6 +3996,23 @@ def _server(self, sock: socket | None) -> Connection:
39953996
)
39963997
server_ctx.check_privatekey()
39973998
server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
3999+
return server_ctx
4000+
4001+
def _server(
4002+
self, sock: socket | None, ctx: Context | None = None
4003+
) -> Connection:
4004+
"""
4005+
Create a new server-side SSL `Connection` object wrapped around `sock`.
4006+
4007+
:param sock: The socket to wrap, or None for memory BIO.
4008+
:param ctx: Optional pre-configured context. If None, creates a
4009+
default server context.
4010+
"""
4011+
if ctx is None:
4012+
server_ctx = self._create_server_context()
4013+
else:
4014+
server_ctx = ctx
4015+
39984016
# Here the Connection is actually created. If None is passed as the
39994017
# 2nd parameter, it indicates a memory BIO should be created.
40004018
server_conn = Connection(server_ctx, sock)
@@ -4204,12 +4222,16 @@ def _check_client_ca_list(
42044222
that `get_client_ca_list` returns the proper value at
42054223
various times.
42064224
"""
4207-
server = self._server(None)
4225+
# Create a server context and configure it before creating connections
4226+
server_ctx = self._create_server_context()
4227+
4228+
# Configure the CA list before creating connections
4229+
expected = func(server_ctx)
4230+
4231+
# Now create connections with the configured context
4232+
server = self._server(None, server_ctx)
42084233
client = self._client(None)
4209-
assert client.get_client_ca_list() == []
4210-
assert server.get_client_ca_list() == []
4211-
ctx = server.get_context()
4212-
expected = func(ctx)
4234+
42134235
assert client.get_client_ca_list() == []
42144236
assert server.get_client_ca_list() == expected
42154237
interact_in_memory(client, server)

0 commit comments

Comments
 (0)