|
12 | 12 | import tempfile
|
13 | 13 | import threading
|
14 | 14 | import time
|
| 15 | +import unittest.mock |
15 | 16 | import weakref
|
16 | 17 | import unittest
|
17 | 18 |
|
@@ -1431,6 +1432,166 @@ def wrapper(sock):
|
1431 | 1432 | with self.tcp_server(run(eof_server)) as srv:
|
1432 | 1433 | self.loop.run_until_complete(client(srv.addr))
|
1433 | 1434 |
|
| 1435 | + def test_remote_shutdown_receives_trailing_data_on_slow_socket(self): |
| 1436 | + # This test is the same as test_remote_shutdown_receives_trailing_data, |
| 1437 | + # except it simulates a socket that is not able to write data in time, |
| 1438 | + # thus triggering different code path in _SelectorSocketTransport. |
| 1439 | + # This triggers bug gh-115514, also tested using mocks in |
| 1440 | + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close |
| 1441 | + # The slow path is triggered here by setting SO_SNDBUF, see code and comment below. |
| 1442 | + |
| 1443 | + CHUNK = 1024 * 128 |
| 1444 | + SIZE = 32 |
| 1445 | + |
| 1446 | + sslctx = self._create_server_ssl_context( |
| 1447 | + test_utils.ONLYCERT, |
| 1448 | + test_utils.ONLYKEY |
| 1449 | + ) |
| 1450 | + client_sslctx = self._create_client_ssl_context() |
| 1451 | + future = None |
| 1452 | + |
| 1453 | + def server(sock): |
| 1454 | + incoming = ssl.MemoryBIO() |
| 1455 | + outgoing = ssl.MemoryBIO() |
| 1456 | + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) |
| 1457 | + |
| 1458 | + while True: |
| 1459 | + try: |
| 1460 | + sslobj.do_handshake() |
| 1461 | + except ssl.SSLWantReadError: |
| 1462 | + if outgoing.pending: |
| 1463 | + sock.send(outgoing.read()) |
| 1464 | + incoming.write(sock.recv(16384)) |
| 1465 | + else: |
| 1466 | + if outgoing.pending: |
| 1467 | + sock.send(outgoing.read()) |
| 1468 | + break |
| 1469 | + |
| 1470 | + while True: |
| 1471 | + try: |
| 1472 | + data = sslobj.read(4) |
| 1473 | + except ssl.SSLWantReadError: |
| 1474 | + incoming.write(sock.recv(16384)) |
| 1475 | + else: |
| 1476 | + break |
| 1477 | + |
| 1478 | + self.assertEqual(data, b'ping') |
| 1479 | + sslobj.write(b'pong') |
| 1480 | + sock.send(outgoing.read()) |
| 1481 | + |
| 1482 | + time.sleep(0.2) # wait for the peer to fill its backlog |
| 1483 | + |
| 1484 | + # send close_notify but don't wait for response |
| 1485 | + with self.assertRaises(ssl.SSLWantReadError): |
| 1486 | + sslobj.unwrap() |
| 1487 | + sock.send(outgoing.read()) |
| 1488 | + |
| 1489 | + # should receive all data |
| 1490 | + data_len = 0 |
| 1491 | + while True: |
| 1492 | + try: |
| 1493 | + chunk = len(sslobj.read(16384)) |
| 1494 | + data_len += chunk |
| 1495 | + except ssl.SSLWantReadError: |
| 1496 | + incoming.write(sock.recv(16384)) |
| 1497 | + except ssl.SSLZeroReturnError: |
| 1498 | + break |
| 1499 | + |
| 1500 | + self.assertEqual(data_len, CHUNK * SIZE*2) |
| 1501 | + |
| 1502 | + # verify that close_notify is received |
| 1503 | + sslobj.unwrap() |
| 1504 | + |
| 1505 | + sock.close() |
| 1506 | + |
| 1507 | + def eof_server(sock): |
| 1508 | + sock.starttls(sslctx, server_side=True) |
| 1509 | + self.assertEqual(sock.recv_all(4), b'ping') |
| 1510 | + sock.send(b'pong') |
| 1511 | + |
| 1512 | + time.sleep(0.2) # wait for the peer to fill its backlog |
| 1513 | + |
| 1514 | + # send EOF |
| 1515 | + sock.shutdown(socket.SHUT_WR) |
| 1516 | + |
| 1517 | + # should receive all data |
| 1518 | + data = sock.recv_all(CHUNK * SIZE) |
| 1519 | + self.assertEqual(len(data), CHUNK * SIZE) |
| 1520 | + |
| 1521 | + sock.close() |
| 1522 | + |
| 1523 | + async def client(addr): |
| 1524 | + nonlocal future |
| 1525 | + future = self.loop.create_future() |
| 1526 | + |
| 1527 | + reader, writer = await asyncio.open_connection( |
| 1528 | + *addr, |
| 1529 | + ssl=client_sslctx, |
| 1530 | + server_hostname='') |
| 1531 | + writer.write(b'ping') |
| 1532 | + data = await reader.readexactly(4) |
| 1533 | + self.assertEqual(data, b'pong') |
| 1534 | + |
| 1535 | + # fill write backlog in a hacky way - renegotiation won't help |
| 1536 | + for _ in range(SIZE*2): |
| 1537 | + writer.transport._test__append_write_backlog(b'x' * CHUNK) |
| 1538 | + |
| 1539 | + try: |
| 1540 | + data = await reader.read() |
| 1541 | + self.assertEqual(data, b'') |
| 1542 | + except (BrokenPipeError, ConnectionResetError): |
| 1543 | + pass |
| 1544 | + |
| 1545 | + # Make sure _SelectorSocketTransport enters the delayed write |
| 1546 | + # path in its `write` method by wrapping socket in a fake class |
| 1547 | + # that acts as if there is not enough space in socket buffer. |
| 1548 | + # This triggers bug gh-115514, also tested using mocks in |
| 1549 | + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close |
| 1550 | + socket_transport = writer.transport._ssl_protocol._transport |
| 1551 | + |
| 1552 | + class SocketWrapper: |
| 1553 | + def __init__(self, sock) -> None: |
| 1554 | + self.sock = sock |
| 1555 | + |
| 1556 | + def __getattr__(self, name): |
| 1557 | + return getattr(self.sock, name) |
| 1558 | + |
| 1559 | + def send(self, data): |
| 1560 | + # Fake that our write buffer is full, send only half |
| 1561 | + to_send = len(data)//2 |
| 1562 | + return self.sock.send(data[:to_send]) |
| 1563 | + |
| 1564 | + def _fake_full_write_buffer(data): |
| 1565 | + if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper): |
| 1566 | + socket_transport._sock = SocketWrapper(socket_transport._sock) |
| 1567 | + return unittest.mock.DEFAULT |
| 1568 | + |
| 1569 | + with unittest.mock.patch.object( |
| 1570 | + socket_transport, "write", |
| 1571 | + wraps=socket_transport.write, |
| 1572 | + side_effect=_fake_full_write_buffer |
| 1573 | + ): |
| 1574 | + await future |
| 1575 | + |
| 1576 | + writer.close() |
| 1577 | + await self.wait_closed(writer) |
| 1578 | + |
| 1579 | + def run(meth): |
| 1580 | + def wrapper(sock): |
| 1581 | + try: |
| 1582 | + meth(sock) |
| 1583 | + except Exception as ex: |
| 1584 | + self.loop.call_soon_threadsafe(future.set_exception, ex) |
| 1585 | + else: |
| 1586 | + self.loop.call_soon_threadsafe(future.set_result, None) |
| 1587 | + return wrapper |
| 1588 | + |
| 1589 | + with self.tcp_server(run(server)) as srv: |
| 1590 | + self.loop.run_until_complete(client(srv.addr)) |
| 1591 | + |
| 1592 | + with self.tcp_server(run(eof_server)) as srv: |
| 1593 | + self.loop.run_until_complete(client(srv.addr)) |
| 1594 | + |
1434 | 1595 | def test_connect_timeout_warning(self):
|
1435 | 1596 | s = socket.socket(socket.AF_INET)
|
1436 | 1597 | s.bind(('127.0.0.1', 0))
|
|
0 commit comments