Skip to content

Commit

Permalink
Implemented the Happy Eyeballs algorithm for connect_tcp()
Browse files Browse the repository at this point in the history
Fixes #69.
  • Loading branch information
agronholm committed Oct 4, 2019
1 parent 40f7547 commit d0ffbec
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 20 deletions.
76 changes: 57 additions & 19 deletions anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import contextmanager
from importlib import import_module
from ssl import SSLContext
from typing import TypeVar, Callable, Union, Optional, Awaitable, Coroutine, Any, Dict
from typing import TypeVar, Callable, Union, Optional, Awaitable, Coroutine, Any, Dict, List

import sniffio

Expand Down Expand Up @@ -276,11 +276,18 @@ def aopen(file: Union[str, 'os.PathLike', int], mode: str = 'r', buffering: int
async def connect_tcp(
address: IPAddressType, port: int, *, ssl_context: Optional[SSLContext] = None,
autostart_tls: bool = False, bind_host: Optional[IPAddressType] = None,
bind_port: Optional[int] = None, tls_standard_compatible: bool = True
bind_port: Optional[int] = None, tls_standard_compatible: bool = True,
happy_eyeballs_delay: float = 0.25
) -> SocketStream:
"""
Connect to a host using the TCP protocol.
This function implements the stateless version of the Happy Eyeballs algorithm (RFC 6555).
If ``address`` is a host name that resolves to multiple IP addresses, each one is tried until
one connection attempt succeeds. If the first attempt does not connected within 300
milliseconds, a second attempt is started using the next address in the list, and so on.
For IPv6 enabled systems, IPv6 addresses are tried first.
:param address: the IP address or host name to connect to
:param port: port on the target host to connect to
:param ssl_context: default SSL context to use for TLS handshakes
Expand All @@ -292,34 +299,65 @@ async def connect_tcp(
:exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
Some protocols, such as HTTP, require this option to be ``False``.
See :meth:`~ssl.SSLContext.wrap_socket` for details.
:param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
:return: a socket stream object
:raises OSError: if the connection attempt fails
"""
# Placed here due to https://github.com/python/mypy/issues/7057
stream = None # type: Optional[SocketStream]

async def try_connect(af: int, addr: str, delay: float):
nonlocal stream

if delay:
await sleep(delay)

raw_socket = socket.socket(af, socket.SOCK_STREAM)
sock = asynclib.Socket(raw_socket)
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if interface is not None and bind_port is not None:
await sock.bind((interface, bind_port))

await sock.connect((addr, port))
except OSError as exc:
oserrors.append(exc)
await sock.close()
return
except BaseException:
await sock.close()
raise

assert stream is None
stream = _networking.SocketStream(sock, ssl_context, str(address), tls_standard_compatible)
await tg.cancel_scope.cancel()

asynclib = _get_asynclib()
interface, family = None, 0 # type: Optional[str], int
if bind_host:
interface, family, _v6only = await _networking.get_bind_address(bind_host)

# getaddrinfo() will raise an exception if name resolution fails
address = str(address)
addrlist = await run_in_thread(socket.getaddrinfo, address, port, family, socket.SOCK_STREAM)
family, type_, proto, _cn, sa = addrlist[0]
raw_socket = socket.socket(family, type_, proto)
sock = _get_asynclib().Socket(raw_socket)
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if interface is not None and bind_port is not None:
await sock.bind((interface, bind_port))
addrlist = await run_in_thread(socket.getaddrinfo, str(address), port, family,
socket.SOCK_STREAM)

await sock.connect(sa)
stream = _networking.SocketStream(sock, ssl_context, address, tls_standard_compatible)
# Sort the list so that IPv4 addresses are tried last
addresses = sorted(((item[0], item[-1][0]) for item in addrlist),
key=lambda item: item[0] == socket.AF_INET)

if autostart_tls:
await stream.start_tls()
oserrors = [] # type: List[OSError]
async with create_task_group() as tg:
for i, (af, addr) in enumerate(addresses):
await tg.spawn(try_connect, af, addr, i * happy_eyeballs_delay)

return stream
except BaseException:
await sock.close()
raise
if stream is None:
raise OSError('All connection attempts failed') from asynclib.ExceptionGroup(oserrors)

if autostart_tls:
await stream.start_tls()

return stream


async def connect_unix(path: Union[str, 'os.PathLike']) -> SocketStream:
Expand Down
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Added the possibility to parametrize regular pytest test functions against the selected list of
backends
- Implemented the Happy Eyeballs (:rfc:`6555`) algorithm for ``anyio.connect_tcp()``
- Fixed ``KeyError`` on asyncio and curio where entering and exiting a cancel scope happens in
different tasks
- Dropped support for trio v0.11
Expand Down
47 changes: 46 additions & 1 deletion tests/test_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
create_task_group, connect_tcp, create_udp_socket, connect_unix, create_unix_server,
create_tcp_server, wait_all_tasks_blocked)
from anyio.exceptions import (
IncompleteRead, DelimiterNotFound, ClosedResourceError, ResourceBusyError)
IncompleteRead, DelimiterNotFound, ClosedResourceError, ResourceBusyError, ExceptionGroup)


@pytest.fixture(scope='module')
Expand Down Expand Up @@ -263,6 +263,51 @@ async def receive_data():
finally:
await tg.cancel_scope.cancel()

@pytest.mark.skipif(not socket.has_ipv6, reason='IPv6 is not available')
@pytest.mark.parametrize('interface, expected_addr', [
(None, b'::1'),
('127.0.0.1', b'127.0.0.1'),
('::1', b'::1')
])
@pytest.mark.anyio
async def test_happy_eyeballs(self, interface, expected_addr, monkeypatch):
async def handle_client(stream):
addr, port, *rest = stream._socket._raw_socket.getpeername()
await stream.send_all(addr.encode() + b'\n')

async def server():
async for stream in stream_server.accept_connections():
await tg.spawn(handle_client, stream)

# Fake getaddrinfo() to return IPv4 addresses first so we can test the IPv6 preference
fake_results = [(socket.AF_INET, socket.SOCK_STREAM, '', ('127.0.0.1', 0)),
(socket.AF_INET6, socket.SOCK_STREAM, '', ('::1', 0))]
monkeypatch.setattr('socket.getaddrinfo', lambda *args: fake_results)

async with await create_tcp_server(interface=interface) as stream_server:
async with create_task_group() as tg:
await tg.spawn(server)
async with await connect_tcp('localhost', stream_server.port) as client:
assert await client.receive_until(b'\n', 100) == expected_addr

await stream_server.close()

@pytest.mark.skipif(not socket.has_ipv6, reason='IPv6 is not available')
@pytest.mark.anyio
async def test_happy_eyeballs_connrefused(self):
dummy_socket = socket.socket(socket.AF_INET6)
dummy_socket.bind(('::', 0))
free_port = dummy_socket.getsockname()[1]
dummy_socket.close()

with pytest.raises(OSError) as exc:
await connect_tcp('localhost', free_port)

assert exc.match('All connection attempts failed')
assert isinstance(exc.value.__cause__, ExceptionGroup)
for exc in exc.value.__cause__.exceptions:
assert isinstance(exc, ConnectionRefusedError)


class TestUNIXStream:
@pytest.mark.skipif(sys.platform == 'win32',
Expand Down

0 comments on commit d0ffbec

Please sign in to comment.