From d8b449241490db36752cd45ff41a0daa35f8d57a Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Tue, 18 Oct 2022 19:09:09 +0800 Subject: [PATCH 1/9] Add tests for happy eyeballs and its internal workings. One of the new tests may fail intermittently due to #86296. --- Lib/test/test_asyncio/test_base_events.py | 111 +++++++++++++++++++++ Lib/test/test_asyncio/test_staggerd.py | 115 ++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 Lib/test/test_asyncio/test_staggerd.py diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 2dcb20c1cec7f9..c5fb4f85e86238 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -145,6 +145,58 @@ def test_ipaddr_info_no_inet_pton(self, m_socket): socket.SOCK_STREAM, socket.IPPROTO_TCP)) + def test_interleave_ipaddrs(self): + addrinfos = [ + (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), + (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), + (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), + ] + + self.assertEqual( + [ + (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), + (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), + (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), + ], + base_events._interleave_addrinfos(addrinfos) + ) + + def test_interleave_ipaddrs_first_address_family_count(self): + addrinfos = [ + (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), + (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), + (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), + ] + + self.assertEqual( + [ + (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), + (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), + (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), + (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), + ], + base_events._interleave_addrinfos(addrinfos, 2) + ) + class BaseEventLoopTests(test_utils.TestCase): @@ -1431,6 +1483,65 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support') + @patch_socket + def test_create_connection_happy_eyeballs(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), + (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))] + + async def sock_connect(sock, address): + if address[0] == '2001:db8::1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.1', 5)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + + @patch_socket + def test_create_connection_happy_eyeballs_ipv4_only(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), + (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))] + + async def sock_connect(sock, address): + if address[0] == '192.0.2.1': + await asyncio.sleep(1) + sock.connect(address) + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = sock_connect + + coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3) + transport, protocol = self.loop.run_until_complete(coro) + try: + sock = transport._sock + sock.connect.assert_called_with(('192.0.2.2', 6)) + finally: + transport.close() + test_utils.run_briefly(self.loop) # allow transport to close + @patch_socket def test_create_connection_bluetooth(self, m_socket): # See http://bugs.python.org/issue27136, fallback to getaddrinfo when diff --git a/Lib/test/test_asyncio/test_staggerd.py b/Lib/test/test_asyncio/test_staggerd.py new file mode 100644 index 00000000000000..e2585eb2f41878 --- /dev/null +++ b/Lib/test/test_asyncio/test_staggerd.py @@ -0,0 +1,115 @@ +import asyncio +import functools +import unittest +from asyncio.staggered import staggered_race + + +# To prevent a warning "test altered the execution environment" +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestStaggered(unittest.IsolatedAsyncioTestCase): + @staticmethod + async def waiting_coroutine(return_value, wait_seconds, success): + await asyncio.sleep(wait_seconds) + if success: + return return_value + raise RuntimeError(str(return_value)) + + def get_waiting_coroutine_factory(self, return_value, wait_seconds, success): + return functools.partial(self.waiting_coroutine, return_value, wait_seconds, success) + + async def test_single_success(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, True),), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 1) + self.assertIsNone(exceptions[0]) + + async def test_single_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + (self.get_waiting_coroutine_factory(0, 0.1, False),), + 0.1, + ) + self.assertEqual(winner_result, None) + self.assertEqual(winner_idx, None) + self.assertEqual(len(exceptions), 1) + self.assertIsInstance(exceptions[0], RuntimeError) + + async def test_first_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], asyncio.CancelledError) + + async def test_second_win(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.3, True), + self.get_waiting_coroutine_factory(1, 0.1, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], asyncio.CancelledError) + self.assertIsNone(exceptions[1]) + + async def test_first_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, False), + self.get_waiting_coroutine_factory(1, 0.2, True), + ), + 0.1, + ) + self.assertEqual(winner_result, 1) + self.assertEqual(winner_idx, 1) + self.assertEqual(len(exceptions), 2) + self.assertIsInstance(exceptions[0], RuntimeError) + self.assertIsNone(exceptions[1]) + + async def test_second_fail(self): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.2, True), + self.get_waiting_coroutine_factory(1, 0, False), + ), + 0.1, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + self.assertEqual(len(exceptions), 2) + self.assertIsNone(exceptions[0]) + self.assertIsInstance(exceptions[1], RuntimeError) + + async def test_simultaneous_success_fail(self): + # There's a potential race condition here: + # https://github.com/python/cpython/issues/86296 + for _ in range(50): + winner_result, winner_idx, exceptions = await staggered_race( + ( + self.get_waiting_coroutine_factory(0, 0.1, True), + self.get_waiting_coroutine_factory(1, 0.05, False), + self.get_waiting_coroutine_factory(2, 0.05, True) + ), + 0.05, + ) + self.assertEqual(winner_result, 0) + self.assertEqual(winner_idx, 0) + + + From 34a8f6f6009c225579bfbd0f3e21e9522ea71892 Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Tue, 18 Oct 2022 23:48:53 +0800 Subject: [PATCH 2/9] Correct typo in filename. --- Lib/test/test_asyncio/{test_staggerd.py => test_staggered.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename Lib/test/test_asyncio/{test_staggerd.py => test_staggered.py} (100%) diff --git a/Lib/test/test_asyncio/test_staggerd.py b/Lib/test/test_asyncio/test_staggered.py similarity index 100% rename from Lib/test/test_asyncio/test_staggerd.py rename to Lib/test/test_asyncio/test_staggered.py From c0e179e8c648b2f5217c911682e92dbe3e2e905e Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Wed, 19 Oct 2022 01:19:26 +0800 Subject: [PATCH 3/9] Improve readability for _interleave_addrinfos tests. --- Lib/test/test_asyncio/test_base_events.py | 59 ++++++----------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index c5fb4f85e86238..cba456f833f9a5 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -146,55 +146,24 @@ def test_ipaddr_info_no_inet_pton(self, m_socket): socket.IPPROTO_TCP)) def test_interleave_ipaddrs(self): - addrinfos = [ - (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), - (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), - (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), - (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), - (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), - ] + SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)) + SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)) + SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)) + SIX_D = (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)) + FOUR_A = (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)) + FOUR_B = (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)) + FOUR_C = (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)) + FOUR_D = (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)) - self.assertEqual( - [ - (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), - (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), - (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), - (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), - (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), - ], - base_events._interleave_addrinfos(addrinfos) - ) + addrinfos = [SIX_A, SIX_B, SIX_C, SIX_D, FOUR_A, FOUR_B, FOUR_C, FOUR_D] + expected = [SIX_A, FOUR_A, SIX_B, FOUR_B, SIX_C, FOUR_C, SIX_D, FOUR_D] - def test_interleave_ipaddrs_first_address_family_count(self): - addrinfos = [ - (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), - (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), - (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), - (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), - (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), - ] + self.assertEqual(expected, base_events._interleave_addrinfos(addrinfos)) + expected_fafc_2 = [SIX_A, SIX_B, FOUR_A, SIX_C, FOUR_B, SIX_D, FOUR_C, FOUR_D] self.assertEqual( - [ - (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)), - (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)), - (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6)), - (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4)), - (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7)), - (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8)), - ], - base_events._interleave_addrinfos(addrinfos, 2) + expected_fafc_2, + base_events._interleave_addrinfos(addrinfos, first_address_family_count=2), ) From 8270e6976665ba24fbe927a9b0ac1fdb55a3cee0 Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Wed, 19 Oct 2022 01:20:06 +0800 Subject: [PATCH 4/9] Rename _interleave_addrinfos tests to match the method being tested. --- Lib/test/test_asyncio/test_base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index cba456f833f9a5..8e368ead7c8f00 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -145,7 +145,7 @@ def test_ipaddr_info_no_inet_pton(self, m_socket): socket.SOCK_STREAM, socket.IPPROTO_TCP)) - def test_interleave_ipaddrs(self): + def test_interleave_addrinfos(self): SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)) SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2)) SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3)) From afb6abfae9e1510ec65cc5ee493f7d0893675bfa Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Wed, 19 Oct 2022 01:28:29 +0800 Subject: [PATCH 5/9] Made test_simultaneous_success_fail fail more consistently (I hope). --- Lib/test/test_asyncio/test_staggered.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index e2585eb2f41878..4cb82afa419d00 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -99,14 +99,17 @@ async def test_second_fail(self): async def test_simultaneous_success_fail(self): # There's a potential race condition here: # https://github.com/python/cpython/issues/86296 - for _ in range(50): + # As with any race condition, it can be difficult to reproduce. + # This test may not fail every time. + for i in range(201): + time_unit = 0.0001 * i winner_result, winner_idx, exceptions = await staggered_race( ( - self.get_waiting_coroutine_factory(0, 0.1, True), - self.get_waiting_coroutine_factory(1, 0.05, False), + self.get_waiting_coroutine_factory(0, time_unit*2, True), + self.get_waiting_coroutine_factory(1, time_unit, False), self.get_waiting_coroutine_factory(2, 0.05, True) ), - 0.05, + time_unit, ) self.assertEqual(winner_result, 0) self.assertEqual(winner_idx, 0) From 513f7f911690bbd704de0168564d83cd8c5fadd9 Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Wed, 19 Oct 2022 12:02:17 +0800 Subject: [PATCH 6/9] Replace `wait_for` with `wait`. This solution is compatible with all Python versions, and should pass all tests. --- Lib/asyncio/staggered.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 451a53a16f3831..5b74045bea8bca 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -2,11 +2,9 @@ __all__ = 'staggered_race', -import contextlib import typing from . import events -from . import exceptions as exceptions_mod from . import locks from . import tasks @@ -83,12 +81,11 @@ async def run_one_coro( previous_failed: typing.Optional[locks.Event]) -> None: # Wait for the previous task to finish, or for delay seconds if previous_failed is not None: - with contextlib.suppress(exceptions_mod.TimeoutError): - # Use asyncio.wait_for() instead of asyncio.wait() here, so - # that if we get cancelled at this point, Event.wait() is also - # cancelled, otherwise there will be a "Task destroyed but it is - # pending" later. - await tasks.wait_for(previous_failed.wait(), delay) + wait_task = tasks.create_task(previous_failed.wait()) + try: + await tasks.wait((wait_task,), timeout=delay) + finally: + wait_task.cancel() # Get the next coroutine to run try: this_index, coro_fn = next(enum_coro_fns) From 079eec82bb0eef7143639fc1339c080e884ed63a Mon Sep 17 00:00:00 2001 From: Oleg Iarygin Date: Wed, 26 Apr 2023 17:27:44 +0400 Subject: [PATCH 7/9] Delete staggered.py --- Lib/asyncio/staggered.py | 146 --------------------------------------- 1 file changed, 146 deletions(-) delete mode 100644 Lib/asyncio/staggered.py diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py deleted file mode 100644 index 5b74045bea8bca..00000000000000 --- a/Lib/asyncio/staggered.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Support for running coroutines in parallel with staggered start times.""" - -__all__ = 'staggered_race', - -import typing - -from . import events -from . import locks -from . import tasks - - -async def staggered_race( - coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], - delay: typing.Optional[float], - *, - loop: events.AbstractEventLoop = None, -) -> typing.Tuple[ - typing.Any, - typing.Optional[int], - typing.List[typing.Optional[Exception]] -]: - """Run coroutines with staggered start times and take the first to finish. - - This method takes an iterable of coroutine functions. The first one is - started immediately. From then on, whenever the immediately preceding one - fails (raises an exception), or when *delay* seconds has passed, the next - coroutine is started. This continues until one of the coroutines complete - successfully, in which case all others are cancelled, or until all - coroutines fail. - - The coroutines provided should be well-behaved in the following way: - - * They should only ``return`` if completed successfully. - - * They should always raise an exception if they did not complete - successfully. In particular, if they handle cancellation, they should - probably reraise, like this:: - - try: - # do work - except asyncio.CancelledError: - # undo partially completed work - raise - - Args: - coro_fns: an iterable of coroutine functions, i.e. callables that - return a coroutine object when called. Use ``functools.partial`` or - lambdas to pass arguments. - - delay: amount of time, in seconds, between starting coroutines. If - ``None``, the coroutines will run sequentially. - - loop: the event loop to use. - - Returns: - tuple *(winner_result, winner_index, exceptions)* where - - - *winner_result*: the result of the winning coroutine, or ``None`` - if no coroutines won. - - - *winner_index*: the index of the winning coroutine in - ``coro_fns``, or ``None`` if no coroutines won. If the winning - coroutine may return None on success, *winner_index* can be used - to definitively determine whether any coroutine won. - - - *exceptions*: list of exceptions returned by the coroutines. - ``len(exceptions)`` is equal to the number of coroutines actually - started, and the order is the same as in ``coro_fns``. The winning - coroutine's entry is ``None``. - - """ - # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. - loop = loop or events.get_running_loop() - enum_coro_fns = enumerate(coro_fns) - winner_result = None - winner_index = None - exceptions = [] - running_tasks = [] - - async def run_one_coro( - previous_failed: typing.Optional[locks.Event]) -> None: - # Wait for the previous task to finish, or for delay seconds - if previous_failed is not None: - wait_task = tasks.create_task(previous_failed.wait()) - try: - await tasks.wait((wait_task,), timeout=delay) - finally: - wait_task.cancel() - # Get the next coroutine to run - try: - this_index, coro_fn = next(enum_coro_fns) - except StopIteration: - return - # Start task that will run the next coroutine - this_failed = locks.Event() - next_task = loop.create_task(run_one_coro(this_failed)) - running_tasks.append(next_task) - assert len(running_tasks) == this_index + 2 - # Prepare place to put this coroutine's exceptions if not won - exceptions.append(None) - assert len(exceptions) == this_index + 1 - - try: - result = await coro_fn() - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as e: - exceptions[this_index] = e - this_failed.set() # Kickstart the next coroutine - else: - # Store winner's results - nonlocal winner_index, winner_result - assert winner_index is None - winner_index = this_index - winner_result = result - # Cancel all other tasks. We take care to not cancel the current - # task as well. If we do so, then since there is no `await` after - # here and CancelledError are usually thrown at one, we will - # encounter a curious corner case where the current task will end - # up as done() == True, cancelled() == False, exception() == - # asyncio.CancelledError. This behavior is specified in - # https://bugs.python.org/issue30048 - for i, t in enumerate(running_tasks): - if i != this_index: - t.cancel() - - first_task = loop.create_task(run_one_coro(None)) - running_tasks.append(first_task) - try: - # Wait for a growing list of tasks to all finish: poor man's version of - # curio's TaskGroup or trio's nursery - done_count = 0 - while done_count != len(running_tasks): - done, _ = await tasks.wait(running_tasks) - done_count = len(done) - # If run_one_coro raises an unhandled exception, it's probably a - # programming error, and I want to see it. - if __debug__: - for d in done: - if d.done() and not d.cancelled() and d.exception(): - raise d.exception() - return winner_result, winner_index, exceptions - finally: - # Make sure no tasks are left running if we leave this function - for t in running_tasks: - t.cancel() From 5ec7d65496d52b5cea2becb48ae0100b43d4ede9 Mon Sep 17 00:00:00 2001 From: Oleg Iarygin Date: Wed, 26 Apr 2023 19:50:51 +0400 Subject: [PATCH 8/9] Restore staggered.py from `main` branch --- Lib/asyncio/staggered.py | 149 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 Lib/asyncio/staggered.py diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py new file mode 100644 index 00000000000000..451a53a16f3831 --- /dev/null +++ b/Lib/asyncio/staggered.py @@ -0,0 +1,149 @@ +"""Support for running coroutines in parallel with staggered start times.""" + +__all__ = 'staggered_race', + +import contextlib +import typing + +from . import events +from . import exceptions as exceptions_mod +from . import locks +from . import tasks + + +async def staggered_race( + coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], + delay: typing.Optional[float], + *, + loop: events.AbstractEventLoop = None, +) -> typing.Tuple[ + typing.Any, + typing.Optional[int], + typing.List[typing.Optional[Exception]] +]: + """Run coroutines with staggered start times and take the first to finish. + + This method takes an iterable of coroutine functions. The first one is + started immediately. From then on, whenever the immediately preceding one + fails (raises an exception), or when *delay* seconds has passed, the next + coroutine is started. This continues until one of the coroutines complete + successfully, in which case all others are cancelled, or until all + coroutines fail. + + The coroutines provided should be well-behaved in the following way: + + * They should only ``return`` if completed successfully. + + * They should always raise an exception if they did not complete + successfully. In particular, if they handle cancellation, they should + probably reraise, like this:: + + try: + # do work + except asyncio.CancelledError: + # undo partially completed work + raise + + Args: + coro_fns: an iterable of coroutine functions, i.e. callables that + return a coroutine object when called. Use ``functools.partial`` or + lambdas to pass arguments. + + delay: amount of time, in seconds, between starting coroutines. If + ``None``, the coroutines will run sequentially. + + loop: the event loop to use. + + Returns: + tuple *(winner_result, winner_index, exceptions)* where + + - *winner_result*: the result of the winning coroutine, or ``None`` + if no coroutines won. + + - *winner_index*: the index of the winning coroutine in + ``coro_fns``, or ``None`` if no coroutines won. If the winning + coroutine may return None on success, *winner_index* can be used + to definitively determine whether any coroutine won. + + - *exceptions*: list of exceptions returned by the coroutines. + ``len(exceptions)`` is equal to the number of coroutines actually + started, and the order is the same as in ``coro_fns``. The winning + coroutine's entry is ``None``. + + """ + # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) + winner_result = None + winner_index = None + exceptions = [] + running_tasks = [] + + async def run_one_coro( + previous_failed: typing.Optional[locks.Event]) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(exceptions_mod.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 + + try: + result = await coro_fn() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + exceptions[this_index] = e + this_failed.set() # Kickstart the next coroutine + else: + # Store winner's results + nonlocal winner_index, winner_result + assert winner_index is None + winner_index = this_index + winner_result = result + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) + try: + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. + if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() From 4e9f06a192ede48dd77f47cfb0fb3058ca64fce1 Mon Sep 17 00:00:00 2001 From: Oleg Iarygin Date: Wed, 26 Apr 2023 22:54:35 +0400 Subject: [PATCH 9/9] Address patchcheck report --- Lib/test/test_asyncio/test_staggered.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index 4cb82afa419d00..775f6f0901fa59 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -113,6 +113,3 @@ async def test_simultaneous_success_fail(self): ) self.assertEqual(winner_result, 0) self.assertEqual(winner_idx, 0) - - -