Skip to content

Commit 632785d

Browse files
Tasssadarkumaraditya303
authored andcommitted
pythongh-115514: Fix incomplete writes after close while using ssl in asyncio(python#128037)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
1 parent cee0ebc commit 632785d

File tree

5 files changed

+213
-5
lines changed

5 files changed

+213
-5
lines changed

Lib/asyncio/selector_events.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1185,10 +1185,13 @@ def can_write_eof(self):
11851185
return True
11861186

11871187
def _call_connection_lost(self, exc):
1188-
super()._call_connection_lost(exc)
1189-
if self._empty_waiter is not None:
1190-
self._empty_waiter.set_exception(
1191-
ConnectionError("Connection is closed by peer"))
1188+
try:
1189+
super()._call_connection_lost(exc)
1190+
finally:
1191+
self._write_ready = None
1192+
if self._empty_waiter is not None:
1193+
self._empty_waiter.set_exception(
1194+
ConnectionError("Connection is closed by peer"))
11921195

11931196
def _make_empty_waiter(self):
11941197
if self._empty_waiter is not None:
@@ -1203,7 +1206,6 @@ def _reset_empty_waiter(self):
12031206

12041207
def close(self):
12051208
self._read_ready_cb = None
1206-
self._write_ready = None
12071209
super().close()
12081210

12091211

Lib/test/test_asyncio/test_selector_events.py

+42
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,48 @@ def test_transport_close_remove_writer(self, m_log):
10511051
transport.close()
10521052
remove_writer.assert_called_with(self.sock_fd)
10531053

1054+
def test_write_buffer_after_close(self):
1055+
# gh-115514: If the transport is closed while:
1056+
# * Transport write buffer is not empty
1057+
# * Transport is paused
1058+
# * Protocol has data in its buffer, like SSLProtocol in self._outgoing
1059+
# The data is still written out.
1060+
1061+
# Also tested with real SSL transport in
1062+
# test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data
1063+
1064+
data = memoryview(b'data')
1065+
self.sock.send.return_value = 2
1066+
self.sock.send.fileno.return_value = 7
1067+
1068+
def _resume_writing():
1069+
transport.write(b"data")
1070+
self.protocol.resume_writing.side_effect = None
1071+
1072+
self.protocol.resume_writing.side_effect = _resume_writing
1073+
1074+
transport = self.socket_transport()
1075+
transport._high_water = 1
1076+
1077+
transport.write(data)
1078+
1079+
self.assertTrue(transport._protocol_paused)
1080+
self.assertTrue(self.sock.send.called)
1081+
self.loop.assert_writer(7, transport._write_ready)
1082+
1083+
transport.close()
1084+
1085+
# not called, we still have data in write buffer
1086+
self.assertFalse(self.protocol.connection_lost.called)
1087+
1088+
self.loop.writers[7]._run()
1089+
# during this ^ run, the _resume_writing mock above was called and added more data
1090+
1091+
self.assertEqual(transport.get_write_buffer_size(), 2)
1092+
self.loop.writers[7]._run()
1093+
1094+
self.assertEqual(transport.get_write_buffer_size(), 0)
1095+
self.assertTrue(self.protocol.connection_lost.called)
10541096

10551097
class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
10561098

Lib/test/test_asyncio/test_ssl.py

+161
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tempfile
1313
import threading
1414
import time
15+
import unittest.mock
1516
import weakref
1617
import unittest
1718

@@ -1431,6 +1432,166 @@ def wrapper(sock):
14311432
with self.tcp_server(run(eof_server)) as srv:
14321433
self.loop.run_until_complete(client(srv.addr))
14331434

1435+
def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
1436+
# This test is the same as test_remote_shutdown_receives_trailing_data,
1437+
# except it simulates a socket that is not able to write data in time,
1438+
# thus triggering different code path in _SelectorSocketTransport.
1439+
# This triggers bug gh-115514, also tested using mocks in
1440+
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
1441+
# The slow path is triggered here by setting SO_SNDBUF, see code and comment below.
1442+
1443+
CHUNK = 1024 * 128
1444+
SIZE = 32
1445+
1446+
sslctx = self._create_server_ssl_context(
1447+
test_utils.ONLYCERT,
1448+
test_utils.ONLYKEY
1449+
)
1450+
client_sslctx = self._create_client_ssl_context()
1451+
future = None
1452+
1453+
def server(sock):
1454+
incoming = ssl.MemoryBIO()
1455+
outgoing = ssl.MemoryBIO()
1456+
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
1457+
1458+
while True:
1459+
try:
1460+
sslobj.do_handshake()
1461+
except ssl.SSLWantReadError:
1462+
if outgoing.pending:
1463+
sock.send(outgoing.read())
1464+
incoming.write(sock.recv(16384))
1465+
else:
1466+
if outgoing.pending:
1467+
sock.send(outgoing.read())
1468+
break
1469+
1470+
while True:
1471+
try:
1472+
data = sslobj.read(4)
1473+
except ssl.SSLWantReadError:
1474+
incoming.write(sock.recv(16384))
1475+
else:
1476+
break
1477+
1478+
self.assertEqual(data, b'ping')
1479+
sslobj.write(b'pong')
1480+
sock.send(outgoing.read())
1481+
1482+
time.sleep(0.2) # wait for the peer to fill its backlog
1483+
1484+
# send close_notify but don't wait for response
1485+
with self.assertRaises(ssl.SSLWantReadError):
1486+
sslobj.unwrap()
1487+
sock.send(outgoing.read())
1488+
1489+
# should receive all data
1490+
data_len = 0
1491+
while True:
1492+
try:
1493+
chunk = len(sslobj.read(16384))
1494+
data_len += chunk
1495+
except ssl.SSLWantReadError:
1496+
incoming.write(sock.recv(16384))
1497+
except ssl.SSLZeroReturnError:
1498+
break
1499+
1500+
self.assertEqual(data_len, CHUNK * SIZE*2)
1501+
1502+
# verify that close_notify is received
1503+
sslobj.unwrap()
1504+
1505+
sock.close()
1506+
1507+
def eof_server(sock):
1508+
sock.starttls(sslctx, server_side=True)
1509+
self.assertEqual(sock.recv_all(4), b'ping')
1510+
sock.send(b'pong')
1511+
1512+
time.sleep(0.2) # wait for the peer to fill its backlog
1513+
1514+
# send EOF
1515+
sock.shutdown(socket.SHUT_WR)
1516+
1517+
# should receive all data
1518+
data = sock.recv_all(CHUNK * SIZE)
1519+
self.assertEqual(len(data), CHUNK * SIZE)
1520+
1521+
sock.close()
1522+
1523+
async def client(addr):
1524+
nonlocal future
1525+
future = self.loop.create_future()
1526+
1527+
reader, writer = await asyncio.open_connection(
1528+
*addr,
1529+
ssl=client_sslctx,
1530+
server_hostname='')
1531+
writer.write(b'ping')
1532+
data = await reader.readexactly(4)
1533+
self.assertEqual(data, b'pong')
1534+
1535+
# fill write backlog in a hacky way - renegotiation won't help
1536+
for _ in range(SIZE*2):
1537+
writer.transport._test__append_write_backlog(b'x' * CHUNK)
1538+
1539+
try:
1540+
data = await reader.read()
1541+
self.assertEqual(data, b'')
1542+
except (BrokenPipeError, ConnectionResetError):
1543+
pass
1544+
1545+
# Make sure _SelectorSocketTransport enters the delayed write
1546+
# path in its `write` method by wrapping socket in a fake class
1547+
# that acts as if there is not enough space in socket buffer.
1548+
# This triggers bug gh-115514, also tested using mocks in
1549+
# test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
1550+
socket_transport = writer.transport._ssl_protocol._transport
1551+
1552+
class SocketWrapper:
1553+
def __init__(self, sock) -> None:
1554+
self.sock = sock
1555+
1556+
def __getattr__(self, name):
1557+
return getattr(self.sock, name)
1558+
1559+
def send(self, data):
1560+
# Fake that our write buffer is full, send only half
1561+
to_send = len(data)//2
1562+
return self.sock.send(data[:to_send])
1563+
1564+
def _fake_full_write_buffer(data):
1565+
if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
1566+
socket_transport._sock = SocketWrapper(socket_transport._sock)
1567+
return unittest.mock.DEFAULT
1568+
1569+
with unittest.mock.patch.object(
1570+
socket_transport, "write",
1571+
wraps=socket_transport.write,
1572+
side_effect=_fake_full_write_buffer
1573+
):
1574+
await future
1575+
1576+
writer.close()
1577+
await self.wait_closed(writer)
1578+
1579+
def run(meth):
1580+
def wrapper(sock):
1581+
try:
1582+
meth(sock)
1583+
except Exception as ex:
1584+
self.loop.call_soon_threadsafe(future.set_exception, ex)
1585+
else:
1586+
self.loop.call_soon_threadsafe(future.set_result, None)
1587+
return wrapper
1588+
1589+
with self.tcp_server(run(server)) as srv:
1590+
self.loop.run_until_complete(client(srv.addr))
1591+
1592+
with self.tcp_server(run(eof_server)) as srv:
1593+
self.loop.run_until_complete(client(srv.addr))
1594+
14341595
def test_connect_timeout_warning(self):
14351596
s = socket.socket(socket.AF_INET)
14361597
s.bind(('127.0.0.1', 0))

Misc/ACKS

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ Stéphane Blondon
189189
Eric Blossom
190190
Sergey Bobrov
191191
Finn Bock
192+
Vojtěch Boček
192193
Paul Boddie
193194
Matthew Boedicker
194195
Robin Boerdijk
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix exceptions and incomplete writes after :class:`!asyncio._SelectorTransport`
2+
is closed before writes are completed.

0 commit comments

Comments
 (0)