Skip to content

Commit ce8fc18

Browse files
authored
gh-93357: Start porting asyncio server test cases to IsolatedAsyncioTestCase (#93369)
Lay the foundation for further work in `asyncio.test_streams`.
1 parent 4f380db commit ce8fc18

File tree

1 file changed

+119
-173
lines changed

1 file changed

+119
-173
lines changed

Lib/test/test_asyncio/test_streams.py

+119-173
Original file line numberDiff line numberDiff line change
@@ -566,46 +566,10 @@ def test_exception_cancel(self):
566566
test_utils.run_briefly(self.loop)
567567
self.assertIs(stream._waiter, None)
568568

569-
def test_start_server(self):
570-
571-
class MyServer:
572-
573-
def __init__(self, loop):
574-
self.server = None
575-
self.loop = loop
576-
577-
async def handle_client(self, client_reader, client_writer):
578-
data = await client_reader.readline()
579-
client_writer.write(data)
580-
await client_writer.drain()
581-
client_writer.close()
582-
await client_writer.wait_closed()
583-
584-
def start(self):
585-
sock = socket.create_server(('127.0.0.1', 0))
586-
self.server = self.loop.run_until_complete(
587-
asyncio.start_server(self.handle_client,
588-
sock=sock))
589-
return sock.getsockname()
590-
591-
def handle_client_callback(self, client_reader, client_writer):
592-
self.loop.create_task(self.handle_client(client_reader,
593-
client_writer))
594-
595-
def start_callback(self):
596-
sock = socket.create_server(('127.0.0.1', 0))
597-
addr = sock.getsockname()
598-
sock.close()
599-
self.server = self.loop.run_until_complete(
600-
asyncio.start_server(self.handle_client_callback,
601-
host=addr[0], port=addr[1]))
602-
return addr
603-
604-
def stop(self):
605-
if self.server is not None:
606-
self.server.close()
607-
self.loop.run_until_complete(self.server.wait_closed())
608-
self.server = None
569+
570+
class NewStreamTests(unittest.IsolatedAsyncioTestCase):
571+
572+
async def test_start_server(self):
609573

610574
async def client(addr):
611575
reader, writer = await asyncio.open_connection(*addr)
@@ -617,61 +581,43 @@ async def client(addr):
617581
await writer.wait_closed()
618582
return msgback
619583

620-
messages = []
621-
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
622-
623-
# test the server variant with a coroutine as client handler
624-
server = MyServer(self.loop)
625-
addr = server.start()
626-
msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
627-
server.stop()
628-
self.assertEqual(msg, b"hello world!\n")
584+
async def handle_client(client_reader, client_writer):
585+
data = await client_reader.readline()
586+
client_writer.write(data)
587+
await client_writer.drain()
588+
client_writer.close()
589+
await client_writer.wait_closed()
590+
591+
with self.subTest(msg="coroutine"):
592+
server = await asyncio.start_server(
593+
handle_client,
594+
host=socket_helper.HOSTv4
595+
)
596+
addr = server.sockets[0].getsockname()
597+
msg = await client(addr)
598+
server.close()
599+
await server.wait_closed()
600+
self.assertEqual(msg, b"hello world!\n")
629601

630-
# test the server variant with a callback as client handler
631-
server = MyServer(self.loop)
632-
addr = server.start_callback()
633-
msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
634-
server.stop()
635-
self.assertEqual(msg, b"hello world!\n")
602+
with self.subTest(msg="callback"):
603+
async def handle_client_callback(client_reader, client_writer):
604+
asyncio.get_running_loop().create_task(
605+
handle_client(client_reader, client_writer)
606+
)
636607

637-
self.assertEqual(messages, [])
608+
server = await asyncio.start_server(
609+
handle_client_callback,
610+
host=socket_helper.HOSTv4
611+
)
612+
addr = server.sockets[0].getsockname()
613+
reader, writer = await asyncio.open_connection(*addr)
614+
msg = await client(addr)
615+
server.close()
616+
await server.wait_closed()
617+
self.assertEqual(msg, b"hello world!\n")
638618

639619
@socket_helper.skip_unless_bind_unix_socket
640-
def test_start_unix_server(self):
641-
642-
class MyServer:
643-
644-
def __init__(self, loop, path):
645-
self.server = None
646-
self.loop = loop
647-
self.path = path
648-
649-
async def handle_client(self, client_reader, client_writer):
650-
data = await client_reader.readline()
651-
client_writer.write(data)
652-
await client_writer.drain()
653-
client_writer.close()
654-
await client_writer.wait_closed()
655-
656-
def start(self):
657-
self.server = self.loop.run_until_complete(
658-
asyncio.start_unix_server(self.handle_client,
659-
path=self.path))
660-
661-
def handle_client_callback(self, client_reader, client_writer):
662-
self.loop.create_task(self.handle_client(client_reader,
663-
client_writer))
664-
665-
def start_callback(self):
666-
start = asyncio.start_unix_server(self.handle_client_callback,
667-
path=self.path)
668-
self.server = self.loop.run_until_complete(start)
669-
670-
def stop(self):
671-
if self.server is not None:
672-
self.server.close()
673-
self.loop.run_until_complete(self.server.wait_closed())
674-
self.server = None
620+
async def test_start_unix_server(self):
675621

676622
async def client(path):
677623
reader, writer = await asyncio.open_unix_connection(path)
@@ -683,64 +629,42 @@ async def client(path):
683629
await writer.wait_closed()
684630
return msgback
685631

686-
messages = []
687-
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
688-
689-
# test the server variant with a coroutine as client handler
690-
with test_utils.unix_socket_path() as path:
691-
server = MyServer(self.loop, path)
692-
server.start()
693-
msg = self.loop.run_until_complete(
694-
self.loop.create_task(client(path)))
695-
server.stop()
696-
self.assertEqual(msg, b"hello world!\n")
697-
698-
# test the server variant with a callback as client handler
699-
with test_utils.unix_socket_path() as path:
700-
server = MyServer(self.loop, path)
701-
server.start_callback()
702-
msg = self.loop.run_until_complete(
703-
self.loop.create_task(client(path)))
704-
server.stop()
705-
self.assertEqual(msg, b"hello world!\n")
706-
707-
self.assertEqual(messages, [])
632+
async def handle_client(client_reader, client_writer):
633+
data = await client_reader.readline()
634+
client_writer.write(data)
635+
await client_writer.drain()
636+
client_writer.close()
637+
await client_writer.wait_closed()
638+
639+
with self.subTest(msg="coroutine"):
640+
with test_utils.unix_socket_path() as path:
641+
server = await asyncio.start_unix_server(
642+
handle_client,
643+
path=path
644+
)
645+
msg = await client(path)
646+
server.close()
647+
await server.wait_closed()
648+
self.assertEqual(msg, b"hello world!\n")
649+
650+
with self.subTest(msg="callback"):
651+
async def handle_client_callback(client_reader, client_writer):
652+
asyncio.get_running_loop().create_task(
653+
handle_client(client_reader, client_writer)
654+
)
655+
656+
with test_utils.unix_socket_path() as path:
657+
server = await asyncio.start_unix_server(
658+
handle_client_callback,
659+
path=path
660+
)
661+
msg = await client(path)
662+
server.close()
663+
await server.wait_closed()
664+
self.assertEqual(msg, b"hello world!\n")
708665

709666
@unittest.skipIf(ssl is None, 'No ssl module')
710-
def test_start_tls(self):
711-
712-
class MyServer:
713-
714-
def __init__(self, loop):
715-
self.server = None
716-
self.loop = loop
717-
718-
async def handle_client(self, client_reader, client_writer):
719-
data1 = await client_reader.readline()
720-
client_writer.write(data1)
721-
await client_writer.drain()
722-
assert client_writer.get_extra_info('sslcontext') is None
723-
await client_writer.start_tls(
724-
test_utils.simple_server_sslcontext())
725-
assert client_writer.get_extra_info('sslcontext') is not None
726-
data2 = await client_reader.readline()
727-
client_writer.write(data2)
728-
await client_writer.drain()
729-
client_writer.close()
730-
await client_writer.wait_closed()
731-
732-
def start(self):
733-
sock = socket.create_server(('127.0.0.1', 0))
734-
self.server = self.loop.run_until_complete(
735-
asyncio.start_server(self.handle_client,
736-
sock=sock))
737-
return sock.getsockname()
738-
739-
def stop(self):
740-
if self.server is not None:
741-
self.server.close()
742-
self.loop.run_until_complete(self.server.wait_closed())
743-
self.server = None
667+
async def test_start_tls(self):
744668

745669
async def client(addr):
746670
reader, writer = await asyncio.open_connection(*addr)
@@ -757,18 +681,49 @@ async def client(addr):
757681
await writer.wait_closed()
758682
return msgback1, msgback2
759683

760-
messages = []
761-
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
762-
763-
server = MyServer(self.loop)
764-
addr = server.start()
765-
msg1, msg2 = self.loop.run_until_complete(client(addr))
766-
server.stop()
767-
768-
self.assertEqual(messages, [])
684+
async def handle_client(client_reader, client_writer):
685+
data1 = await client_reader.readline()
686+
client_writer.write(data1)
687+
await client_writer.drain()
688+
assert client_writer.get_extra_info('sslcontext') is None
689+
await client_writer.start_tls(
690+
test_utils.simple_server_sslcontext())
691+
assert client_writer.get_extra_info('sslcontext') is not None
692+
693+
data2 = await client_reader.readline()
694+
client_writer.write(data2)
695+
await client_writer.drain()
696+
client_writer.close()
697+
await client_writer.wait_closed()
698+
699+
server = await asyncio.start_server(
700+
handle_client,
701+
host=socket_helper.HOSTv4
702+
)
703+
addr = server.sockets[0].getsockname()
704+
705+
msg1, msg2 = await client(addr)
706+
server.close()
707+
await server.wait_closed()
769708
self.assertEqual(msg1, b"hello world 1!\n")
770709
self.assertEqual(msg2, b"hello world 2!\n")
771710

711+
712+
class StreamTests2(test_utils.TestCase):
713+
714+
def setUp(self):
715+
super().setUp()
716+
self.loop = asyncio.new_event_loop()
717+
self.set_event_loop(self.loop)
718+
719+
def tearDown(self):
720+
# just in case if we have transport close callbacks
721+
test_utils.run_briefly(self.loop)
722+
723+
self.loop.close()
724+
gc.collect()
725+
super().tearDown()
726+
772727
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
773728
def test_read_all_from_pipe_reader(self):
774729
# See asyncio issue 168. This test is derived from the example
@@ -986,22 +941,20 @@ def test_LimitOverrunError_pickleable(self):
986941
self.assertEqual(str(e), str(e2))
987942
self.assertEqual(e.consumed, e2.consumed)
988943

989-
def test_wait_closed_on_close(self):
990-
with test_utils.run_test_server() as httpd:
944+
async def test_wait_closed_on_close(self):
945+
async with test_utils.run_test_server() as httpd:
991946
rd, wr = self.loop.run_until_complete(
992947
asyncio.open_connection(*httpd.address))
993948

994949
wr.write(b'GET / HTTP/1.0\r\n\r\n')
995-
f = rd.readline()
996-
data = self.loop.run_until_complete(f)
950+
data = await rd.readline()
997951
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
998-
f = rd.read()
999-
data = self.loop.run_until_complete(f)
952+
await rd.read()
1000953
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
1001954
self.assertFalse(wr.is_closing())
1002955
wr.close()
1003956
self.assertTrue(wr.is_closing())
1004-
self.loop.run_until_complete(wr.wait_closed())
957+
await wr.wait_closed()
1005958

1006959
def test_wait_closed_on_close_with_unread_data(self):
1007960
with test_utils.run_test_server() as httpd:
@@ -1057,15 +1010,10 @@ async def inner(httpd):
10571010

10581011
self.assertEqual(messages, [])
10591012

1060-
def test_eof_feed_when_closing_writer(self):
1013+
async def test_eof_feed_when_closing_writer(self):
10611014
# See http://bugs.python.org/issue35065
1062-
messages = []
1063-
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
1064-
1065-
with test_utils.run_test_server() as httpd:
1066-
rd, wr = self.loop.run_until_complete(
1067-
asyncio.open_connection(*httpd.address))
1068-
1015+
async with test_utils.run_test_server() as httpd:
1016+
rd, wr = await asyncio.open_connection(*httpd.address)
10691017
wr.close()
10701018
f = wr.wait_closed()
10711019
self.loop.run_until_complete(f)
@@ -1074,8 +1022,6 @@ def test_eof_feed_when_closing_writer(self):
10741022
data = self.loop.run_until_complete(f)
10751023
self.assertEqual(data, b'')
10761024

1077-
self.assertEqual(messages, [])
1078-
10791025

10801026
if __name__ == '__main__':
10811027
unittest.main()

0 commit comments

Comments
 (0)