From 353f73dc23d23cd327ede08bcf1f672c438995f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20Bo=C4=8Dek?= Date: Sun, 2 Feb 2025 16:11:25 +0100 Subject: [PATCH] gh-115514: Fix incomplete writes after close while using ssl in asyncio(GH-128037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 4e38eeafe2ff3bfc686514731d6281fed34a435e) Co-authored-by: Vojtěch Boček Co-authored-by: Kumar Aditya --- Lib/asyncio/selector_events.py | 12 +- Lib/test/test_asyncio/test_selector_events.py | 42 +++++ Lib/test/test_asyncio/test_ssl.py | 161 ++++++++++++++++++ Misc/ACKS | 1 + ...-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst | 2 + 5 files changed, 213 insertions(+), 5 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index f1ab9b12d69a5d..f847d7df3f8c1a 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -1181,10 +1181,13 @@ def can_write_eof(self): return True def _call_connection_lost(self, exc): - super()._call_connection_lost(exc) - if self._empty_waiter is not None: - self._empty_waiter.set_exception( - ConnectionError("Connection is closed by peer")) + try: + super()._call_connection_lost(exc) + finally: + self._write_ready = None + if self._empty_waiter is not None: + self._empty_waiter.set_exception( + ConnectionError("Connection is closed by peer")) def _make_empty_waiter(self): if self._empty_waiter is not None: @@ -1199,7 +1202,6 @@ def _reset_empty_waiter(self): def close(self): self._read_ready_cb = None - self._write_ready = None super().close() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index efca30f37414f9..893c9e2fba3d21 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -1026,6 +1026,48 @@ def test_transport_close_remove_writer(self, m_log): transport.close() remove_writer.assert_called_with(self.sock_fd) + def test_write_buffer_after_close(self): + # gh-115514: If the transport is closed while: + # * Transport write buffer is not empty + # * Transport is paused + # * Protocol has data in its buffer, like SSLProtocol in self._outgoing + # The data is still written out. + + # Also tested with real SSL transport in + # test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data + + data = memoryview(b'data') + self.sock.send.return_value = 2 + self.sock.send.fileno.return_value = 7 + + def _resume_writing(): + transport.write(b"data") + self.protocol.resume_writing.side_effect = None + + self.protocol.resume_writing.side_effect = _resume_writing + + transport = self.socket_transport() + transport._high_water = 1 + + transport.write(data) + + self.assertTrue(transport._protocol_paused) + self.assertTrue(self.sock.send.called) + self.loop.assert_writer(7, transport._write_ready) + + transport.close() + + # not called, we still have data in write buffer + self.assertFalse(self.protocol.connection_lost.called) + + self.loop.writers[7]._run() + # during this ^ run, the _resume_writing mock above was called and added more data + + self.assertEqual(transport.get_write_buffer_size(), 2) + self.loop.writers[7]._run() + + self.assertEqual(transport.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.connection_lost.called) class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase): diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py index e072ede29ee3c7..e4ab5a9024c956 100644 --- a/Lib/test/test_asyncio/test_ssl.py +++ b/Lib/test/test_asyncio/test_ssl.py @@ -12,6 +12,7 @@ import tempfile import threading import time +import unittest.mock import weakref import unittest @@ -1431,6 +1432,166 @@ def wrapper(sock): with self.tcp_server(run(eof_server)) as srv: self.loop.run_until_complete(client(srv.addr)) + def test_remote_shutdown_receives_trailing_data_on_slow_socket(self): + # This test is the same as test_remote_shutdown_receives_trailing_data, + # except it simulates a socket that is not able to write data in time, + # thus triggering different code path in _SelectorSocketTransport. + # This triggers bug gh-115514, also tested using mocks in + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close + # The slow path is triggered here by setting SO_SNDBUF, see code and comment below. + + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + while True: + try: + data = sslobj.read(4) + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + else: + break + + self.assertEqual(data, b'ping') + sslobj.write(b'pong') + sock.send(outgoing.read()) + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) + + # should receive all data + data_len = 0 + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break + + self.assertEqual(data_len, CHUNK * SIZE*2) + + # verify that close_notify is received + sslobj.unwrap() + + sock.close() + + def eof_server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send EOF + sock.shutdown(socket.SHUT_WR) + + # should receive all data + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + # fill write backlog in a hacky way - renegotiation won't help + for _ in range(SIZE*2): + writer.transport._test__append_write_backlog(b'x' * CHUNK) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (BrokenPipeError, ConnectionResetError): + pass + + # Make sure _SelectorSocketTransport enters the delayed write + # path in its `write` method by wrapping socket in a fake class + # that acts as if there is not enough space in socket buffer. + # This triggers bug gh-115514, also tested using mocks in + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close + socket_transport = writer.transport._ssl_protocol._transport + + class SocketWrapper: + def __init__(self, sock) -> None: + self.sock = sock + + def __getattr__(self, name): + return getattr(self.sock, name) + + def send(self, data): + # Fake that our write buffer is full, send only half + to_send = len(data)//2 + return self.sock.send(data[:to_send]) + + def _fake_full_write_buffer(data): + if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper): + socket_transport._sock = SocketWrapper(socket_transport._sock) + return unittest.mock.DEFAULT + + with unittest.mock.patch.object( + socket_transport, "write", + wraps=socket_transport.write, + side_effect=_fake_full_write_buffer + ): + await future + + writer.close() + await self.wait_closed(writer) + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(eof_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + def test_connect_timeout_warning(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) diff --git a/Misc/ACKS b/Misc/ACKS index 206c9b849cf0ea..54e6076773793b 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -188,6 +188,7 @@ Stéphane Blondon Eric Blossom Sergey Bobrov Finn Bock +Vojtěch Boček Paul Boddie Matthew Boedicker Robin Boerdijk diff --git a/Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst b/Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst new file mode 100644 index 00000000000000..24e836a0b0b7f9 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst @@ -0,0 +1,2 @@ +Fix exceptions and incomplete writes after :class:`!asyncio._SelectorTransport` +is closed before writes are completed.