Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-98388: Add tests for happy eyeballs and its internal workings #98389

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
80 changes: 80 additions & 0 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ def test_ipaddr_info_no_inet_pton(self, m_socket):
socket.SOCK_STREAM,
socket.IPPROTO_TCP))

def test_interleave_addrinfos(self):
SIX_A = (socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1))
SIX_B = (socket.AF_INET6, 0, 0, '', ('2001:db8::2', 2))
SIX_C = (socket.AF_INET6, 0, 0, '', ('2001:db8::3', 3))
SIX_D = (socket.AF_INET6, 0, 0, '', ('2001:db8::4', 4))
FOUR_A = (socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))
FOUR_B = (socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))
FOUR_C = (socket.AF_INET, 0, 0, '', ('192.0.2.3', 7))
FOUR_D = (socket.AF_INET, 0, 0, '', ('192.0.2.4', 8))

addrinfos = [SIX_A, SIX_B, SIX_C, SIX_D, FOUR_A, FOUR_B, FOUR_C, FOUR_D]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be a more thorough test if we mixed up the order a bit, e.g. like this?

Suggested change
addrinfos = [SIX_A, SIX_B, SIX_C, SIX_D, FOUR_A, FOUR_B, FOUR_C, FOUR_D]
addrinfos = [SIX_A, SIX_B, SIX_C, FOUR_A, FOUR_B, FOUR_C, FOUR_D, SIX_D]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM @gvanrossum.

expected = [SIX_A, FOUR_A, SIX_B, FOUR_B, SIX_C, FOUR_C, SIX_D, FOUR_D]

self.assertEqual(expected, base_events._interleave_addrinfos(addrinfos))

expected_fafc_2 = [SIX_A, SIX_B, FOUR_A, SIX_C, FOUR_B, SIX_D, FOUR_C, FOUR_D]
self.assertEqual(
expected_fafc_2,
base_events._interleave_addrinfos(addrinfos, first_address_family_count=2),
)


class BaseEventLoopTests(test_utils.TestCase):

Expand Down Expand Up @@ -1426,6 +1447,65 @@ def getaddrinfo_task(*args, **kwds):
self.assertRaises(
OSError, self.loop.run_until_complete, coro)

@unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support')
@patch_socket
def test_create_connection_happy_eyeballs(self, m_socket):

class MyProto(asyncio.Protocol):
pass

async def getaddrinfo(*args, **kw):
return [(socket.AF_INET6, 0, 0, '', ('2001:db8::1', 1)),
(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5))]

async def sock_connect(sock, address):
if address[0] == '2001:db8::1':
await asyncio.sleep(1)
sock.connect(address)

self.loop._add_reader = mock.Mock()
self.loop._add_writer = mock.Mock()
self.loop.getaddrinfo = getaddrinfo
self.loop.sock_connect = sock_connect

coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3)
transport, protocol = self.loop.run_until_complete(coro)
try:
sock = transport._sock
sock.connect.assert_called_with(('192.0.2.1', 5))
finally:
transport.close()
test_utils.run_briefly(self.loop) # allow transport to close

@patch_socket
def test_create_connection_happy_eyeballs_ipv4_only(self, m_socket):

class MyProto(asyncio.Protocol):
pass

async def getaddrinfo(*args, **kw):
return [(socket.AF_INET, 0, 0, '', ('192.0.2.1', 5)),
(socket.AF_INET, 0, 0, '', ('192.0.2.2', 6))]

async def sock_connect(sock, address):
if address[0] == '192.0.2.1':
await asyncio.sleep(1)
sock.connect(address)

self.loop._add_reader = mock.Mock()
self.loop._add_writer = mock.Mock()
self.loop.getaddrinfo = getaddrinfo
self.loop.sock_connect = sock_connect

coro = self.loop.create_connection(MyProto, 'example.com', 80, happy_eyeballs_delay=0.3)
transport, protocol = self.loop.run_until_complete(coro)
try:
sock = transport._sock
sock.connect.assert_called_with(('192.0.2.2', 6))
finally:
transport.close()
test_utils.run_briefly(self.loop) # allow transport to close

@patch_socket
def test_create_connection_bluetooth(self, m_socket):
# See http://bugs.python.org/issue27136, fallback to getaddrinfo when
Expand Down
115 changes: 115 additions & 0 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import asyncio
import functools
import unittest
from asyncio.staggered import staggered_race


# To prevent a warning "test altered the execution environment"
def tearDownModule():
asyncio.set_event_loop_policy(None)


class TestStaggered(unittest.IsolatedAsyncioTestCase):
@staticmethod
async def waiting_coroutine(return_value, wait_seconds, success):
await asyncio.sleep(wait_seconds)
if success:
return return_value
raise RuntimeError(str(return_value))

def get_waiting_coroutine_factory(self, return_value, wait_seconds, success):
return functools.partial(self.waiting_coroutine, return_value, wait_seconds, success)

async def test_single_success(self):
winner_result, winner_idx, exceptions = await staggered_race(
(self.get_waiting_coroutine_factory(0, 0.1, True),),
0.1,
)
self.assertEqual(winner_result, 0)
self.assertEqual(winner_idx, 0)
self.assertEqual(len(exceptions), 1)
self.assertIsNone(exceptions[0])

async def test_single_fail(self):
winner_result, winner_idx, exceptions = await staggered_race(
(self.get_waiting_coroutine_factory(0, 0.1, False),),
0.1,
)
self.assertEqual(winner_result, None)
self.assertEqual(winner_idx, None)
self.assertEqual(len(exceptions), 1)
self.assertIsInstance(exceptions[0], RuntimeError)

async def test_first_win(self):
winner_result, winner_idx, exceptions = await staggered_race(
(
self.get_waiting_coroutine_factory(0, 0.2, True),
self.get_waiting_coroutine_factory(1, 0.2, True),
),
0.1,
)
self.assertEqual(winner_result, 0)
self.assertEqual(winner_idx, 0)
self.assertEqual(len(exceptions), 2)
self.assertIsNone(exceptions[0])
self.assertIsInstance(exceptions[1], asyncio.CancelledError)

async def test_second_win(self):
winner_result, winner_idx, exceptions = await staggered_race(
(
self.get_waiting_coroutine_factory(0, 0.3, True),
self.get_waiting_coroutine_factory(1, 0.1, True),
),
0.1,
)
self.assertEqual(winner_result, 1)
self.assertEqual(winner_idx, 1)
self.assertEqual(len(exceptions), 2)
self.assertIsInstance(exceptions[0], asyncio.CancelledError)
self.assertIsNone(exceptions[1])

async def test_first_fail(self):
winner_result, winner_idx, exceptions = await staggered_race(
(
self.get_waiting_coroutine_factory(0, 0.2, False),
self.get_waiting_coroutine_factory(1, 0.2, True),
),
0.1,
)
self.assertEqual(winner_result, 1)
self.assertEqual(winner_idx, 1)
self.assertEqual(len(exceptions), 2)
self.assertIsInstance(exceptions[0], RuntimeError)
self.assertIsNone(exceptions[1])

async def test_second_fail(self):
winner_result, winner_idx, exceptions = await staggered_race(
(
self.get_waiting_coroutine_factory(0, 0.2, True),
self.get_waiting_coroutine_factory(1, 0, False),
),
0.1,
)
self.assertEqual(winner_result, 0)
self.assertEqual(winner_idx, 0)
self.assertEqual(len(exceptions), 2)
self.assertIsNone(exceptions[0])
self.assertIsInstance(exceptions[1], RuntimeError)

async def test_simultaneous_success_fail(self):
# There's a potential race condition here:
# https://github.com/python/cpython/issues/86296
# As with any race condition, it can be difficult to reproduce.
# This test may not fail every time.
Comment on lines +100 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a note to this comment that the race condition has been fixed in 3.12, e.g.

Suggested change
# There's a potential race condition here:
# https://github.com/python/cpython/issues/86296
# As with any race condition, it can be difficult to reproduce.
# This test may not fail every time.
# There's a potential race condition here (fixed in Python 3.12):
# https://github.com/python/cpython/issues/86296
# As with any race condition, it can be difficult to reproduce.
# This test may not fail every time.

for i in range(201):
time_unit = 0.0001 * i
winner_result, winner_idx, exceptions = await staggered_race(
(
self.get_waiting_coroutine_factory(0, time_unit*2, True),
self.get_waiting_coroutine_factory(1, time_unit, False),
self.get_waiting_coroutine_factory(2, 0.05, True)
),
time_unit,
)
self.assertEqual(winner_result, 0)
self.assertEqual(winner_idx, 0)
Loading