Skip to content

Commit

Permalink
Avoid calling await getaddrinfo(...) in exception handler (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert authored Oct 28, 2024
1 parent c484425 commit 3a62738
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)

**4.6.2**

- Fixed regression caused by (`#807 <https://github.com/agronholm/anyio/pull/807>`_)
Expand Down
15 changes: 9 additions & 6 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ async def try_connect(remote_host: str, event: Event) -> None:
try:
addr_obj = ip_address(remote_host)
except ValueError:
addr_obj = None

if addr_obj is not None:
if isinstance(addr_obj, IPv6Address):
target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
else:
target_addrs = [(socket.AF_INET, addr_obj.compressed)]
else:
# getaddrinfo() will raise an exception if name resolution fails
gai_res = await getaddrinfo(
target_host, remote_port, family=family, type=socket.SOCK_STREAM
Expand All @@ -194,7 +202,7 @@ async def try_connect(remote_host: str, event: Event) -> None:
# Organize the list so that the first address is an IPv6 address (if available)
# and the second one is an IPv4 addresses. The rest can be in whatever order.
v6_found = v4_found = False
target_addrs: list[tuple[socket.AddressFamily, str]] = []
target_addrs = []
for af, *rest, sa in gai_res:
if af == socket.AF_INET6 and not v6_found:
v6_found = True
Expand All @@ -204,11 +212,6 @@ async def try_connect(remote_host: str, event: Event) -> None:
target_addrs.insert(1, (af, sa[0]))
else:
target_addrs.append((af, sa[0]))
else:
if isinstance(addr_obj, IPv6Address):
target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
else:
target_addrs = [(socket.AF_INET, addr_obj.compressed)]

oserrors: list[OSError] = []
async with create_task_group() as tg:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,3 +1838,14 @@ async def test_getnameinfo() -> None:
expected_result = socket.getnameinfo(("127.0.0.1", 6666), 0)
result = await getnameinfo(("127.0.0.1", 6666))
assert result == expected_result


async def test_connect_tcp_getaddrinfo_context() -> None:
"""
See https://github.com/agronholm/anyio/issues/815
"""
with pytest.raises(socket.gaierror) as exc_info:
async with await connect_tcp("anyio.invalid", 6666):
pass

assert exc_info.value.__context__ is None

0 comments on commit 3a62738

Please sign in to comment.