Skip to content

Commit

Permalink
Implicitly resolve hostnames in socket methods
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Dec 21, 2017
1 parent 8bff4d5 commit 2440049
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 153 deletions.
68 changes: 6 additions & 62 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ Socket objects
library socket objects <python:socket-objects>`, with a few
important differences:

**Async all the things:** Most obviously, everything is made
"trio-style": blocking methods become async methods, and the
following attributes are *not* supported:
First, and most obviously, everything is made "trio-style":
blocking methods become async methods, and the following attributes
are *not* supported:

* :meth:`~socket.socket.setblocking`: trio sockets always act like
blocking sockets; if you need to read/write from multiple sockets
Expand All @@ -351,71 +351,15 @@ Socket objects
synchronous, so it can't be implemented on top of an async
socket.

**No implicit name resolution:** In the standard library
:mod:`socket` API, there are number of methods that take network
addresses as arguments. When given a numeric address this is fine::

# OK
sock.bind(("127.0.0.1", 80))
sock.connect(("2607:f8b0:4000:80f::200e", 80))

But in the standard library, these methods also accept hostnames,
and in this case implicitly trigger a DNS lookup to find the IP
address::

# Might block!
sock.bind(("localhost", 80))
sock.connect(("google.com", 80))

This is problematic because DNS lookups are a blocking operation.

For simplicity, trio forbids such usages: hostnames must be
"pre-resolved" to numeric addresses before they are passed to
socket methods like :meth:`bind` or :meth:`connect`. In most cases
this can be easily accomplished by calling either
:meth:`resolve_local_address` or :meth:`resolve_remote_address`.

.. method:: resolve_local_address(address)

Resolve the given address into a numeric address suitable for
passing to :meth:`bind`.

This performs the same address resolution that the standard library
:meth:`~socket.socket.bind` call would do, taking into account the
current socket's settings (e.g. if this is an IPv6 socket then it
returns IPv6 addresses). In particular, a hostname of ``None`` is
mapped to the wildcard address.

.. method:: resolve_remote_address(address)

Resolve the given address into a numeric address suitable for
passing to :meth:`connect` or similar.

This performs the same address resolution that the standard library
:meth:`~socket.socket.connect` call would do, taking into account the
current socket's settings (e.g. if this is an IPv6 socket then it
returns IPv6 addresses). In particular, a hostname of ``None`` is
mapped to the localhost address.

The following methods are similar to the equivalents in
:func:`socket.socket`, but have some trio-specific quirks:

.. method:: bind

Bind this socket to the given address.

Unlike the stdlib :meth:`~socket.socket.bind`, this method
requires a pre-resolved address. See
:meth:`resolve_local_address`.
In addition, the following methods are similar to the equivalents
in :func:`socket.socket`, but have some trio-specific quirks:

.. method:: connect
:async:

Connect the socket to a remote address.

Similar to :meth:`socket.socket.connect`, except async and
requiring a pre-resolved address. See
:meth:`resolve_remote_address`.
Similar to :meth:`socket.socket.connect`, except async.

.. warning::

Expand Down
3 changes: 3 additions & 0 deletions newsfragments/377.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Trio socket methods like ``bind`` and ``connect`` no longer require
"pre-resolved" numeric addresses; you can now pass regular hostnames
and Trio will implicitly resolve them for you.
3 changes: 3 additions & 0 deletions newsfragments/377.removal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The ``resolve_local_address`` and ``resolve_remote_address`` methods
on Trio sockets have been deprecated; just pass your hostnames
directly to the socket methods you want to use.
93 changes: 34 additions & 59 deletions trio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def dup(self):

async def bind(self, address):
await _core.checkpoint()
self._check_address(address, require_resolved=True)
address = await self._resolve_local_address(address)
if (hasattr(_stdlib_socket, "AF_UNIX") and self.family == AF_UNIX
and address[0]):
# Use a thread for the filesystem traversal (unless it's an
Expand All @@ -488,62 +488,24 @@ async def wait_writable(self):
# Address handling
################################################################

# For socket operations that take addresses, Python helpfully accepts
# addresses containing names, and implicitly resolves them. This is no
# good, because the implicit resolution is blocking. We require that all
# such addresses be "pre-resolved" meaning:
#
# - For AF_INET or AF_INET6, they must contain only numeric elements. We
# check using getaddrinfo with AI_NUMERIC{HOST,SERV} flags set.
# - For other families, we cross our fingers and hope the user knows what
# they're doing.
#
# And we provide two convenience functions to do this "pre-resolution",
# which attempt to match what Python does.

def _check_address(self, address, *, require_resolved):
# Take an address in Python's representation, and returns a new address in
# the same representation, but with names resolved to numbers,
# etc.
async def _resolve_address(self, address, flags):
# Do some pre-checking (or exit early for non-IP sockets)
if self._sock.family == AF_INET:
if not isinstance(address, tuple) or not len(address) == 2:
await _core.checkpoint()
raise ValueError("address should be a (host, port) tuple")
elif self._sock.family == AF_INET6:
if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
await _core.checkpoint()
raise ValueError(
"address should be a (host, port, [flowinfo, [scopeid]]) "
"tuple"
)
else:
return
if require_resolved: # for AF_INET{,6} only
try:
_stdlib_socket.getaddrinfo(
address[0],
address[1],
self._sock.family,
real_socket_type(self._sock.type),
self._sock.proto,
flags=_NUMERIC_ONLY
)
except gaierror as exc:
if exc.errno == _stdlib_socket.EAI_NONAME:
raise ValueError(
"expected an already-resolved numeric address, not {}"
.format(address)
)
else:
raise

# Take an address in Python's representation, and returns a new address in
# the same representation, but with names resolved to numbers,
# etc.
async def _resolve_address(self, address, flags):
await _core.checkpoint_if_cancelled()
try:
self._check_address(address, require_resolved=False)
except:
await _core.cancel_shielded_checkpoint()
raise
if self._sock.family not in (AF_INET, AF_INET6):
await _core.cancel_shielded_checkpoint()
await _core.checkpoint()
return address
# Since we always pass in an explicit family here, AI_ADDRCONFIG
# doesn't add any value -- if we have no ipv6 connectivity and are
Expand Down Expand Up @@ -574,18 +536,32 @@ async def _resolve_address(self, address, flags):
if len(address) >= 4:
normed[3] = address[3]
normed = tuple(normed)
# Should never fail:
self._check_address(normed, require_resolved=True)
return normed

# Returns something appropriate to pass to bind()
async def resolve_local_address(self, address):
async def _resolve_local_address(self, address):
return await self._resolve_address(address, AI_PASSIVE)

@deprecated(
"0.3.0",
issue=377,
instead="just pass the address to the method you want to use"
)
async def resolve_local_address(self, address):
return await self._resolve_local_address(address)

# Returns something appropriate to pass to connect()/sendto()/sendmsg()
async def resolve_remote_address(self, address):
async def _resolve_remote_address(self, address):
return await self._resolve_address(address, 0)

@deprecated(
"0.3.0",
issue=377,
instead="just pass the address to the method you want to use"
)
async def resolve_remote_address(self, address):
return await self._resolve_remote_address(address)

async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# We have to reconcile two conflicting goals:
# - We want to make it look like we always blocked in doing these
Expand Down Expand Up @@ -661,8 +637,8 @@ async def connect(self, address):
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
# socket if cancelled, to avoid confusion.
address = await self._resolve_remote_address(address)
async with _try_sync():
self._check_address(address, require_resolved=True)
# An interesting puzzle: can a non-blocking connect() return EINTR
# (= raise InterruptedError)? PEP 475 specifically left this as
# the one place where it lets an InterruptedError escape instead
Expand Down Expand Up @@ -786,13 +762,13 @@ async def connect(self, address):

@_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=())
async def sendto(self, *args):
"""Similar to :meth:`socket.socket.sendto`, but async and requiring a
pre-resolved address. See :meth:`resolve_remote_address`.
"""Similar to :meth:`socket.socket.sendto`, but async.
"""
# args is: data[, flags], address)
# and kwargs are not accepted
self._check_address(args[-1], require_resolved=True)
args = list(args)
args[-1] = await self._resolve_remote_address(args[-1])
return await self._nonblocking_helper(
_stdlib_socket.socket.sendto, args, {}, _core.wait_socket_writable
)
Expand All @@ -805,9 +781,7 @@ async def sendto(self, *args):

@_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
async def sendmsg(self, *args):
"""Similar to :meth:`socket.socket.sendmsg`, but async and
requiring a pre-resolved address. See
:meth:`resolve_remote_address`.
"""Similar to :meth:`socket.socket.sendmsg`, but async.
Only available on platforms where :meth:`socket.socket.sendmsg` is
available.
Expand All @@ -816,7 +790,8 @@ async def sendmsg(self, *args):
# args is: buffers[, ancdata[, flags[, address]]]
# and kwargs are not accepted
if len(args) == 4 and args[-1] is not None:
self._check_address(args[-1], require_resolved=True)
args = list(args)
args[-1] = await self._resolve_remote_address(args[-1])
return await self._nonblocking_helper(
_stdlib_socket.socket.sendmsg, args, {},
_core.wait_socket_writable
Expand Down
65 changes: 33 additions & 32 deletions trio/tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,35 +376,30 @@ async def test_SocketType_simple_server(address, socket_type):
assert await client.recv(1) == b"x"


# Direct thorough tests of the implicit resolver helpers
async def test_SocketType_resolve():
sock4 = tsocket.socket(family=tsocket.AF_INET)
with assert_checkpoints():
got = await sock4.resolve_local_address((None, 80))
got = await sock4._resolve_local_address((None, 80))
assert got == ("0.0.0.0", 80)
with assert_checkpoints():
got = await sock4.resolve_remote_address((None, 80))
got = await sock4._resolve_remote_address((None, 80))
assert got == ("127.0.0.1", 80)

sock6 = tsocket.socket(family=tsocket.AF_INET6)
with assert_checkpoints():
got = await sock6.resolve_local_address((None, 80))
got = await sock6._resolve_local_address((None, 80))
assert got == ("::", 80, 0, 0)

with assert_checkpoints():
got = await sock6.resolve_remote_address((None, 80))
got = await sock6._resolve_remote_address((None, 80))
assert got == ("::1", 80, 0, 0)

# AI_PASSIVE only affects the wildcard address, so for everything else
# resolve_local_address and resolve_remote_address should work the same:
for res in ["resolve_local_address", "resolve_remote_address"]:
# _resolve_local_address and _resolve_remote_address should work the same:
for res in ["_resolve_local_address", "_resolve_remote_address"]:

async def s4res(*args):
with assert_checkpoints():
return await getattr(sock4, res)(*args)
return await getattr(sock4, res)(*args)

async def s6res(*args):
with assert_checkpoints():
return await getattr(sock6, res)(*args)
return await getattr(sock6, res)(*args)

assert await s4res(("1.2.3.4", "http")) == ("1.2.3.4", 80)
assert await s6res(("1::2", "http")) == ("1::2", 80, 0, 0)
Expand Down Expand Up @@ -453,20 +448,21 @@ async def s6res(*args):
await s6res(("1.2.3.4", 80, 0, 0, 0))


async def test_SocketType_requires_preresolved(monkeypatch):
sock = tsocket.socket()
with pytest.raises(ValueError):
async def test_SocketType_unresolved_names():
with tsocket.socket() as sock:
await sock.bind(("localhost", 0))
assert sock.getsockname()[0] == "127.0.0.1"
sock.listen(10)

# I don't think it's possible to actually get a gaierror from the way we
# call getaddrinfo in _check_address, but just in case someone finds a
# way, check that it propagates correctly
def gai_oops(*args, **kwargs):
raise tsocket.gaierror("nope!")
with tsocket.socket() as sock2:
await sock2.connect(("localhost", sock.getsockname()[1]))
assert sock2.getpeername() == sock.getsockname()

monkeypatch.setattr(stdlib_socket, "getaddrinfo", gai_oops)
with pytest.raises(tsocket.gaierror):
await sock.bind(("localhost", 0))
# check gaierror propagates out
with tsocket.socket() as sock:
with pytest.raises(tsocket.gaierror):
# definitely not a valid request
await sock.bind(("1.2:3", -1))


# This tests all the complicated paths through _nonblocking_helper, using recv
Expand Down Expand Up @@ -623,11 +619,15 @@ async def test_send_recv_variants():
with a, b:
await a.bind(("127.0.0.1", 0))
await b.bind(("127.0.0.1", 0))
# recvfrom
assert await a.sendto(b"xxx", b.getsockname()) == 3
(data, addr) = await b.recvfrom(10)
assert data == b"xxx"
assert addr == a.getsockname()

targets = [b.getsockname(), ("localhost", b.getsockname()[1])]

# recvfrom + sendto, with and without names
for target in targets:
assert await a.sendto(b"xxx", target) == 3
(data, addr) = await b.recvfrom(10)
assert data == b"xxx"
assert addr == a.getsockname()

# sendto + flags
#
Expand Down Expand Up @@ -675,8 +675,9 @@ async def test_send_recv_variants():
assert addr == a.getsockname()

if hasattr(a, "sendmsg"):
assert await a.sendmsg([b"x", b"yz"], [], 0, b.getsockname()) == 3
assert await b.recvfrom(10) == (b"xyz", a.getsockname())
for target in targets:
assert await a.sendmsg([b"x", b"yz"], [], 0, target) == 3
assert await b.recvfrom(10) == (b"xyz", a.getsockname())

a = tsocket.socket(type=tsocket.SOCK_DGRAM)
b = tsocket.socket(type=tsocket.SOCK_DGRAM)
Expand Down

0 comments on commit 2440049

Please sign in to comment.