Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-115514: Fix incomplete writes after close in asyncio._SelectorSocketTransport #128037

Merged
merged 7 commits into from
Feb 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
@@ -1185,10 +1185,13 @@ def can_write_eof(self):
return True

def _call_connection_lost(self, exc):
super()._call_connection_lost(exc)
if self._empty_waiter is not None:
self._empty_waiter.set_exception(
ConnectionError("Connection is closed by peer"))
try:
super()._call_connection_lost(exc)
finally:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good idea to add a comment explaining why we need this (and link the issue number as well).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the try/finally? It calls user-supplied callback internally, I'm just doing the same thing the parent SelectorTransport is:

def _call_connection_lost(self, exc):
try:
if self._protocol_connected:
self._protocol.connection_lost(exc)
finally:
self._sock.close()
self._sock = None
self._protocol = None
self._loop = None
server = self._server
if server is not None:
server._detach(self)
self._server = None

Do you think I should write a comment here about that?

self._write_ready = None
if self._empty_waiter is not None:
self._empty_waiter.set_exception(
ConnectionError("Connection is closed by peer"))

def _make_empty_waiter(self):
if self._empty_waiter is not None:
@@ -1203,7 +1206,6 @@ def _reset_empty_waiter(self):

def close(self):
self._read_ready_cb = None
self._write_ready = None
super().close()


42 changes: 42 additions & 0 deletions Lib/test/test_asyncio/test_selector_events.py
Original file line number Diff line number Diff line change
@@ -1051,6 +1051,48 @@ def test_transport_close_remove_writer(self, m_log):
transport.close()
remove_writer.assert_called_with(self.sock_fd)

def test_write_buffer_after_close(self):
# gh-115514: If the transport is closed while:
# * Transport write buffer is not empty
# * Transport is paused
# * Protocol has data in its buffer, like SSLProtocol in self._outgoing
# The data is still written out.

# Also tested with real SSL transport in
# test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data

data = memoryview(b'data')
self.sock.send.return_value = 2
self.sock.send.fileno.return_value = 7

def _resume_writing():
transport.write(b"data")
self.protocol.resume_writing.side_effect = None

self.protocol.resume_writing.side_effect = _resume_writing

transport = self.socket_transport()
transport._high_water = 1

transport.write(data)

self.assertTrue(transport._protocol_paused)
self.assertTrue(self.sock.send.called)
self.loop.assert_writer(7, transport._write_ready)

transport.close()

# not called, we still have data in write buffer
self.assertFalse(self.protocol.connection_lost.called)

self.loop.writers[7]._run()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handlers should not be run manually, please change test to not rely on it

Also I think this test should really be in ssl tests not here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, could you please suggest on how to do that? The TestLoop used here does not process the writers when it is ran, so for example self.loop.run_until_complete(asyncio.sleep(0)) does not work.

Also I think this test should really be in ssl tests not here

I don't think so, this is behavior of the SocketTransport, SSL just happens to be the one prevalent user that relies on this behavior.

There already happens to be a more real-world-like test in SSL that almost triggers the bad path here, but I would somehow have to make socket.send not consume the whole data to actually trigger this bug. IMHO its nicer to have the more artificial test here that is easier to understand and triggers the bugged path reliably.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There already happens to be a more real-world-like test in SSL that almost triggers the bad path here, but I would somehow have to make socket.send not consume the whole data to actually trigger this bug. IMHO its nicer to have the more artificial test here that is easier to understand and triggers the bugged path reliably.

You can use mock for socket.send to trigger the bug right? That would be a more realistic test than this one

Copy link
Contributor Author

@Tasssadar Tasssadar Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, the socket still needs to work because it sends data through it. I can shove this there however, artificially limit the socket buffer size which triggers the problem on my computer.

I'm not convinced this is really better, but I can do that in addition to or instead of the mock-based test, what do you think?

index 125a6c35793..b694b4d38db 100644
--- a/Lib/test/test_asyncio/test_ssl.py
+++ b/Lib/test/test_asyncio/test_ssl.py
@@ -12,6 +12,7 @@
 import tempfile
 import threading
 import time
+import unittest.mock
 import weakref
 import unittest

@@ -1410,10 +1411,22 @@ async def client(addr):
             except (BrokenPipeError, ConnectionResetError):
                 pass

-            await future
+            socket_transport = writer.transport._ssl_protocol._transport

-            writer.close()
-            await self.wait_closed(writer)
+            def _shrink_sock_buffer(data):
+                if socket_transport._read_ready_cb is None:
+                    socket_transport._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
+                return unittest.mock.DEFAULT
+
+            with unittest.mock.patch.object(
+                socket_transport, "write",
+                wraps=socket_transport.write,
+                side_effect=_shrink_sock_buffer
+            ):
+                await future
+
+                writer.close()
+                await self.wait_closed(writer)

         def run(meth):
             def wrapper(sock):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think limiting the buffer with SO_SNDBUF makes sense, you can perhaps access the socket using get_extra_info method as well instead of private attrs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in latest commit. Still kept the original mocked test I added - the interaction from SSLProtocol is really hard to track down, and this benefits from having a more straightforward illustration of what is expected from SocketTransport, even if it is all mocks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test isn't failing for me even if I remove your fix in ssl test

Copy link
Contributor Author

@Tasssadar Tasssadar Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On what platform, and what exactly are you changing? On MacOS,while on the branch from this PR, I only have to roll back the fix in selector_events.py like this to make both new tests fail:

EDIT: ha, just tested on Linux and it does succeed there anyway - probably the socket behaves differently. Will debug it and let you know.

diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 22147451fa7..60bef420331 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -1188,7 +1188,7 @@ def _call_connection_lost(self, exc):
         try:
             super()._call_connection_lost(exc)
         finally:
-            self._write_ready = None
+            #self._write_ready = None
             if self._empty_waiter is not None:
                 self._empty_waiter.set_exception(
                     ConnectionError("Connection is closed by peer"))
@@ -1206,6 +1206,7 @@ def _reset_empty_waiter(self):

     def close(self):
         self._read_ready_cb = None
+        self._write_ready = None
         super().close()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, it's fixed now - problem was that on Linux, the minimum write buffer size is 1024, and it can only be set before the socket is connected anyway.

I switched it to use a wrapper class that fakes the full buffer by only writing half the data.

On linux, the ssl test was also not deterministic, it was not sending enough data and they actually got through before close call most of the time, so I doubled the size and now I can reliably trigger the bug with them on both Linux and MacOS.

# during this ^ run, the _resume_writing mock above was called and added more data

self.assertEqual(transport.get_write_buffer_size(), 2)
self.loop.writers[7]._run()

self.assertEqual(transport.get_write_buffer_size(), 0)
self.assertTrue(self.protocol.connection_lost.called)

class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):

161 changes: 161 additions & 0 deletions Lib/test/test_asyncio/test_ssl.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
import tempfile
import threading
import time
import unittest.mock
import weakref
import unittest

@@ -1431,6 +1432,166 @@ def wrapper(sock):
with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
# This test is the same as test_remote_shutdown_receives_trailing_data,
# except it simulates a socket that is not able to write data in time,
# thus triggering different code path in _SelectorSocketTransport.
# This triggers bug gh-115514, also tested using mocks in
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
# The slow path is triggered here by setting SO_SNDBUF, see code and comment below.

CHUNK = 1024 * 128
SIZE = 32

sslctx = self._create_server_ssl_context(
test_utils.ONLYCERT,
test_utils.ONLYKEY
)
client_sslctx = self._create_client_ssl_context()
future = None

def server(sock):
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

while True:
try:
sslobj.do_handshake()
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(16384))
else:
if outgoing.pending:
sock.send(outgoing.read())
break

while True:
try:
data = sslobj.read(4)
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
else:
break

self.assertEqual(data, b'ping')
sslobj.write(b'pong')
sock.send(outgoing.read())

time.sleep(0.2) # wait for the peer to fill its backlog

# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())

# should receive all data
data_len = 0
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break

self.assertEqual(data_len, CHUNK * SIZE*2)

# verify that close_notify is received
sslobj.unwrap()

sock.close()

def eof_server(sock):
sock.starttls(sslctx, server_side=True)
self.assertEqual(sock.recv_all(4), b'ping')
sock.send(b'pong')

time.sleep(0.2) # wait for the peer to fill its backlog

# send EOF
sock.shutdown(socket.SHUT_WR)

# should receive all data
data = sock.recv_all(CHUNK * SIZE)
self.assertEqual(len(data), CHUNK * SIZE)

sock.close()

async def client(addr):
nonlocal future
future = self.loop.create_future()

reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

# fill write backlog in a hacky way - renegotiation won't help
for _ in range(SIZE*2):
writer.transport._test__append_write_backlog(b'x' * CHUNK)

try:
data = await reader.read()
self.assertEqual(data, b'')
except (BrokenPipeError, ConnectionResetError):
pass

# Make sure _SelectorSocketTransport enters the delayed write
# path in its `write` method by wrapping socket in a fake class
# that acts as if there is not enough space in socket buffer.
# This triggers bug gh-115514, also tested using mocks in
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
socket_transport = writer.transport._ssl_protocol._transport

class SocketWrapper:
def __init__(self, sock) -> None:
self.sock = sock

def __getattr__(self, name):
return getattr(self.sock, name)

def send(self, data):
# Fake that our write buffer is full, send only half
to_send = len(data)//2
return self.sock.send(data[:to_send])

def _fake_full_write_buffer(data):
if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
socket_transport._sock = SocketWrapper(socket_transport._sock)
return unittest.mock.DEFAULT

with unittest.mock.patch.object(
socket_transport, "write",
wraps=socket_transport.write,
side_effect=_fake_full_write_buffer
):
await future

writer.close()
await self.wait_closed(writer)

def run(meth):
def wrapper(sock):
try:
meth(sock)
except Exception as ex:
self.loop.call_soon_threadsafe(future.set_exception, ex)
else:
self.loop.call_soon_threadsafe(future.set_result, None)
return wrapper

with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))

with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_connect_timeout_warning(self):
s = socket.socket(socket.AF_INET)
s.bind(('127.0.0.1', 0))
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
@@ -189,6 +189,7 @@ Stéphane Blondon
Eric Blossom
Sergey Bobrov
Finn Bock
Vojtěch Boček
Paul Boddie
Matthew Boedicker
Robin Boerdijk
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix exceptions and incomplete writes after :class:`!asyncio._SelectorTransport`
is closed before writes are completed.