Skip to content

Commit

Permalink
Make it possible to override hostname resolution and socket behavior
Browse files Browse the repository at this point in the history
See python-triogh-170.

Still needs tests.
  • Loading branch information
njsmith committed Jul 27, 2017
1 parent 64a0bbc commit 4e3fa63
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 16 deletions.
21 changes: 13 additions & 8 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,19 @@ standard library :mod:`socket` module. Most constants (like
are simply re-exported unchanged. But there are also some differences,
which are described here.

.. function:: socket(...)
socketpair(...)
fromfd(...)
fromshare(...)

Trio provides analogues to all the standard library functions that
return socket objects; their interface is identical, except that
they're modified to return trio socket objects instead.
First, Trio provides analogues to all the standard library functions
that return socket objects; their interface is identical, except that
they're modified to return trio socket objects instead:

.. autofunction:: socket

.. autofunction:: socketpair

.. autofunction:: fromfd

.. function:: fromshare(data)

Like :func:`socket.fromshare`, but returns a trio socket object.

In addition, there is a new function to directly convert a standard
library socket into a trio socket:
Expand Down
30 changes: 30 additions & 0 deletions docs/source/reference-testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,36 @@ implementations:
.. autofunction:: check_half_closeable_stream


Virtual networking for testing
------------------------------

In the previous section you learned how to use virtual in-memory
streams to test protocols that are written against trio's
:class:`~trio.abc.Stream` abstraction. But what if you have more
complicated networking code – the kind of code that makes connections
to multiple hosts, or opens a listening socket, or sends UDP packets?

Trio doesn't itself provide a virtual in-memory network implementation
for testing – but :mod:`trio.socket` module does provide the hooks you
need to write your own! And if you're interested in helping implement
a reusable virtual network for testing, then `please get in touch
<https://github.com/python-trio/trio/issues/170>`__.

Note that these APIs are actually in :mod:`trio.socket` and
:mod:`trio.abc`, but we document them here because they're primarily
intended for testing.

.. autofunction:: trio.socket.set_custom_hostname_resolver

.. autoclass:: trio.abc.HostnameResolver
:members:

.. autofunction:: trio.socket.set_custom_socket_factory

.. autoclass:: trio.abc.SocketFactory
:members:


Testing checkpoints
--------------------

Expand Down
62 changes: 62 additions & 0 deletions trio/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,68 @@ def after_io_wait(self, timeout):
"""


class HostnameResolver(metaclass=_abc.ABCMeta):
"""If you have a custom hostname resolver, then implementing
:class:`HostnameResolver` allows you to register this to be used by trio.
See :func:`trio.socket.set_custom_hostname_resolver`.
"""
__slots__ = ()

@_abc.abstractmethod
async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
"""A custom implementation of :func:`~trio.socket.getaddrinfo`.
Called by :func:`trio.socket.getaddrinfo`.
If ``host`` is given as a numeric IP address, then
:func:`~trio.socket.getaddrinfo` may handle the request itself rather
than calling this method.
"""

@_abc.abstractmethod
async def getnameinfo(self, sockaddr, flags):
"""A custom implementation of :func:`~trio.socket.getnameinfo`.
Called by :func:`trio.socket.getnameinfo`.
"""


class SocketFactory(metaclass=_abc.ABCMeta):
"""If you write a custom class implementing the trio socket interface,
then you can use a :class:`SocketFactory` to get trio to use it.
See :func:`trio.socket.set_custom_socket_factory`.
"""

@_abc.abstractmethod
def socket(self, family=None, type=None, proto=None):
"""Create and return a socket object.
Called by :func:`trio.socket.socket`.
Note that unlike :func:`trio.socket.socket`, this does not take a
``fileno=`` argument. If a ``fileno=`` is specified, then
:func:`trio.socket.socket` returns a regular trio socket object
instead of calling this method.
"""

@_abc.abstractmethod
def is_trio_socket(self, obj):
"""Check if the given object is a socket instance.
Called by :func:`trio.socket.is_trio_socket`, which returns True if
the given object is a builtin trio socket object *or* if this method
returns True.
"""


# We use ABCMeta instead of ABC, plus setting __slots__=(), so as not to force
# a __dict__ onto subclasses.
class AsyncResource(metaclass=_abc.ABCMeta):
Expand Down
129 changes: 121 additions & 8 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,76 @@ async def __aexit__(self, etype, value, tb):
_reexport(_name)


################################################################
# Overrides
################################################################

_overrides = _core.RunLocal(hostname_resolver=None, socket_factory=None)

def set_custom_hostname_resolver(hostname_resolver):
"""Set a custom hostname resolver.
By default, trio's :func:`getaddrinfo` and :func:`getnameinfo` functions
use the standard system resolver functions. This function allows you to
customize that behavior. The main intended use case is for testing, but it
might also be useful for using third-party resolvers like `c-ares
<https://c-ares.haxx.se/>`__ (though be warned that these rarely make
perfect drop-in replacements for the system resolver). See
:class:`trio.abc.HostnameResolver` for more details.
Setting a custom hostname resolver affects all future calls to
:func:`getaddrinfo` and :func:`getnameinfo` within the enclosing call to
:func:`trio.run`. All other hostname resolution in trio is implemented in
terms of these functions.
Generally you should call this function just once, right at the beginning
of your program.
Args:
hostname_resolver (trio.abc.HostnameResolver or None): The new custom
hostname resolver, or None to restore the default behavior.
Returns:
The previous hostname resolver (which may be None).
"""
old = _overrides.hostname_resolver
_overrides.hostname_resolver = hostname_resolver
return old

__all__.append("set_custom_hostname_resolver")


def set_custom_socket_factory(socket_factory):
"""Set a custom socket object factory.
This function allows you to replace trio's normal socket class with a
custom class. This is very useful for testing, and probably a bad idea in
any other circumstance. See :class:`trio.abc.HostnameResolver` for more
details.
Setting a custom socket factory affects all future calls to :func:`socket`
and :func:`is_trio_socket` within the enclosing call to
:func:`trio.run`.
Generally you should call this function just once, right at the beginning
of your program.
Args:
socket_factory (trio.abc.SocketFactory or None): The new custom
socket factory, or None to restore the default behavior.
Returns:
The previous socket factory (which may be None).
"""
old = _overrides.socket_factory
_overrides.socket_factory = socket_factory
return old

__all__.append("set_custom_socket_factory")


################################################################
# getaddrinfo and friends
################################################################
Expand All @@ -107,6 +177,9 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
different host than the one you intended; see `bpo-17305
<https://bugs.python.org/issue17305>`__.)
This function's behavior can be customized using
:func:`set_custom_hostname_resolver`.
"""

# If host and port are numeric, then getaddrinfo doesn't block and we can
Expand Down Expand Up @@ -134,9 +207,13 @@ def numeric_only_failure(exc):
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
return await _run_in_worker_thread(
_stdlib_socket.getaddrinfo, host, port, family, type, proto, flags,
cancellable=True)
hr = _overrides.hostname_resolver
if hr is not None:
return await hr.getaddrinfo(host, port, family, type, proto, flags)
else:
return await _run_in_worker_thread(
_stdlib_socket.getaddrinfo, host, port, family, type, proto, flags,
cancellable=True)

__all__.append("getaddrinfo")

Expand All @@ -147,9 +224,16 @@ async def getnameinfo(sockaddr, flags):
Arguments and return values are identical to :func:`socket.getnameinfo`,
except that this version is async.
This function's behavior can be customized using
:func:`set_custom_hostname_resolver`.
"""
return await _run_in_worker_thread(
_stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True)
hr = _overrides.hostname_resolver
if hr is not None:
return await hr.getnameinfo(sockaddr, flags)
else:
return await _run_in_worker_thread(
_stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True)

__all__.append("getnameinfo")

Expand All @@ -173,32 +257,55 @@ async def getprotobyname(name):
################################################################

def from_stdlib_socket(sock):
"""Convert a standard library :func:`socket.socket` into a trio socket.
"""Convert a standard library :func:`socket.socket` object into a trio
socket object.
"""
return _SocketType(sock)
__all__.append("from_stdlib_socket")


@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
def fromfd(*args, **kwargs):
"""Like :func:`socket.fromfd`, but returns a trio socket object.
"""
return from_stdlib_socket(_stdlib_socket.fromfd(*args, **kwargs))
__all__.append("fromfd")


if hasattr(_stdlib_socket, "fromshare"):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
def fromshare(*args, **kwargs):
return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs))
__all__.append("fromshare")


@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
def socketpair(*args, **kwargs):
"""Like :func:`socket.socketpair`, but returns a pair of trio socket
objects.
"""
left, right = _stdlib_socket.socketpair(*args, **kwargs)
return (from_stdlib_socket(left), from_stdlib_socket(right))
__all__.append("socketpair")


@_wraps(_stdlib_socket.socket, assigned=(), updated=())
def socket(*args, **kwargs):
return from_stdlib_socket(_stdlib_socket.socket(*args, **kwargs))
def socket(family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
"""Create a new trio socket, like :func:`socket.socket`.
This function's behavior can be customized using
:func:`set_custom_socket_factory`.
"""
if fileno is None:
sf = _overrides.socket_factory
if sf is not None:
return sf.socket(family, type, proto)
stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno)
return from_stdlib_socket(stdlib_socket)
__all__.append("socket")


Expand All @@ -209,7 +316,13 @@ def socket(*args, **kwargs):
def is_trio_socket(obj):
"""Check whether the given object is a trio socket.
This function's behavior can be customized using
:func:`set_custom_socket_factory`.
"""
sf = _overrides.socket_factory
if sf is not None and sf.is_trio_socket(obj):
return True
return isinstance(obj, _SocketType)

__all__.append("is_trio_socket")
Expand Down

0 comments on commit 4e3fa63

Please sign in to comment.