Skip to content

Commit

Permalink
Notify tasks waiting on sockets that they've been closed by another task
Browse files Browse the repository at this point in the history
Fixes #14.
  • Loading branch information
agronholm committed Oct 21, 2018
1 parent e48f0dd commit 8822ef9
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 29 deletions.
12 changes: 7 additions & 5 deletions anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def wait_socket_readable(sock: Union[socket.SocketType, ssl.SSLSocket]) -> Await
Wait until the given socket has data to be read.
:param sock: a socket object
:raises anyio.exceptions.ClosedResourceError: if the socket is closed while waiting
"""
return _get_asynclib().wait_socket_readable(sock)
Expand All @@ -262,6 +263,7 @@ def wait_socket_writable(sock: Union[socket.SocketType, ssl.SSLSocket]) -> Await
Wait until the given socket can be written to.
:param sock: a socket object
:raises anyio.exceptions.ClosedResourceError: if the socket is closed while waiting
"""
return _get_asynclib().wait_socket_writable(sock)
Expand Down Expand Up @@ -302,7 +304,7 @@ async def connect_tcp(

return stream
except BaseException:
sock.close()
await sock.close()
raise


Expand All @@ -322,7 +324,7 @@ async def connect_unix(path: Union[str, Path]) -> SocketStream:
await sock.connect(path)
return _networking.SocketStream(sock)
except BaseException:
sock.close()
await sock.close()
raise


Expand All @@ -349,7 +351,7 @@ async def create_tcp_server(
sock.listen()
return _networking.SocketStreamServer(sock, ssl_context)
except BaseException:
sock.close()
await sock.close()
raise


Expand All @@ -376,7 +378,7 @@ async def create_unix_server(
sock.listen()
return _networking.SocketStreamServer(sock, None)
except BaseException:
sock.close()
await sock.close()
raise


Expand Down Expand Up @@ -413,7 +415,7 @@ async def create_udp_socket(

return _networking.DatagramSocket(sock)
except BaseException:
sock.close()
await sock.close()
raise


Expand Down
13 changes: 12 additions & 1 deletion anyio/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._networking import BaseSocket
from .. import abc, claim_current_thread, _local, T_Retval
from ..exceptions import ExceptionGroup, CancelledError
from ..exceptions import ExceptionGroup, CancelledError, ClosedResourceError

try:
from asyncio import run as native_run, create_task, get_running_loop, current_task, all_tasks
Expand Down Expand Up @@ -417,6 +417,7 @@ async def aopen(*args, **kwargs):
# Networking
#


class Socket(BaseSocket):
__slots__ = '_loop', '_read_event', '_write_event'

Expand All @@ -435,6 +436,9 @@ async def _wait_readable(self) -> None:
finally:
self._loop.remove_reader(self._raw_socket)

if self._raw_socket.fileno() == -1:
raise ClosedResourceError

async def _wait_writable(self) -> None:
check_cancelled()
self._loop.add_writer(self._raw_socket, self._write_event.set)
Expand All @@ -444,6 +448,13 @@ async def _wait_writable(self) -> None:
finally:
self._loop.remove_writer(self._raw_socket)

if self._raw_socket.fileno() == -1:
raise ClosedResourceError

async def _notify_close(self) -> None:
self._read_event.set()
self._write_event.set()

async def _check_cancelled(self) -> None:
check_cancelled()

Expand Down
42 changes: 37 additions & 5 deletions anyio/_backends/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .._networking import BaseSocket
from .. import abc, T_Retval, claim_current_thread, _local
from ..exceptions import ExceptionGroup, CancelledError
from ..exceptions import ExceptionGroup, CancelledError, ClosedResourceError


def run(func: Callable[..., T_Retval], *args) -> T_Retval:
Expand Down Expand Up @@ -244,11 +244,43 @@ async def aopen(*args, **kwargs):
#

class Socket(BaseSocket):
def _wait_readable(self) -> None:
return wait_socket_readable(self._raw_socket)
_reader_tasks = {}
_writer_tasks = {}

def _wait_writable(self) -> None:
return wait_socket_writable(self._raw_socket)
async def _wait_readable(self):
task = await curio.current_task()
self._reader_tasks[self._raw_socket] = task
try:
await curio.traps._read_wait(self._raw_socket)
except curio.TaskCancelled:
if self._raw_socket.fileno() == -1:
raise ClosedResourceError from None
else:
raise
finally:
del self._reader_tasks[self._raw_socket]

async def _wait_writable(self):
task = await curio.current_task()
self._writer_tasks[self._raw_socket] = task
try:
await curio.traps._write_wait(self._raw_socket)
except curio.TaskCancelled:
if self._raw_socket.fileno() == -1:
raise ClosedResourceError from None
else:
raise
finally:
del self._writer_tasks[self._raw_socket]

async def _notify_close(self) -> None:
task = Socket._reader_tasks.get(self._raw_socket)
if task:
await task.cancel(blocking=False)

task = Socket._writer_tasks.get(self._raw_socket)
if task:
await task.cancel(blocking=False)

def _check_cancelled(self) -> Awaitable[None]:
return check_cancelled()
Expand Down
21 changes: 17 additions & 4 deletions anyio/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .._networking import BaseSocket
from .. import abc, claim_current_thread, T_Retval, _local
from ..exceptions import ExceptionGroup
from ..exceptions import ExceptionGroup, ClosedResourceError


class DummyAwaitable:
Expand Down Expand Up @@ -131,18 +131,31 @@ class Socket(BaseSocket):
def _wait_readable(self):
return wait_socket_readable(self._raw_socket)

def _wait_writable(self) -> None:
def _wait_writable(self):
return wait_socket_writable(self._raw_socket)

async def _notify_close(self):
trio.hazmat.notify_socket_close(self._raw_socket)

def _check_cancelled(self) -> None:
return trio.hazmat.checkpoint_if_cancelled()

def _run_in_thread(self, func: Callable, *args):
return run_in_thread(func, *args)


wait_socket_readable = trio.hazmat.wait_socket_readable
wait_socket_writable = trio.hazmat.wait_socket_writable
async def wait_socket_readable(sock):
try:
await trio.hazmat.wait_socket_readable(sock)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None


async def wait_socket_writable(sock):
try:
await trio.hazmat.wait_socket_writable(sock)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None


#
Expand Down
22 changes: 15 additions & 7 deletions anyio/_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ async def _wait_readable(self) -> None:
async def _wait_writable(self) -> None:
pass

@abstractmethod
async def _notify_close(self) -> None:
pass

@abstractmethod
async def _check_cancelled(self) -> None:
pass
Expand Down Expand Up @@ -61,6 +65,10 @@ async def bind(self, address: Union[Tuple[str, int], str]) -> None:
# In all other cases, do this in a worker thread to avoid blocking the event loop thread
await self._run_in_thread(self._raw_socket.bind, address)

async def close(self):
await self._notify_close()
self._raw_socket.close()

async def connect(self, address: Union[tuple, str, bytes]) -> None:
await self._check_cancelled()
try:
Expand Down Expand Up @@ -167,8 +175,8 @@ def __init__(self, sock: BaseSocket, ssl_context: Optional[ssl.SSLContext] = Non
self._ssl_context = ssl_context
self._server_hostname = server_hostname

def close(self):
self._socket.close()
async def close(self):
await self._socket.close()

async def receive_some(self, max_bytes: Optional[int]) -> bytes:
return await self._socket.recv(max_bytes)
Expand Down Expand Up @@ -223,8 +231,8 @@ def __init__(self, sock: BaseSocket, ssl_context: Optional[ssl.SSLContext]) -> N
self._socket = sock
self._ssl_context = ssl_context

def close(self) -> None:
self._socket.close()
async def close(self) -> None:
await self._socket.close()

@property
def address(self) -> Union[tuple, str]:
Expand All @@ -239,7 +247,7 @@ async def accept(self):

return stream
except BaseException:
sock.close()
await sock.close()
raise


Expand All @@ -249,8 +257,8 @@ class DatagramSocket(abc.DatagramSocket):
def __init__(self, sock: BaseSocket) -> None:
self._socket = sock

def close(self):
self._socket.close()
async def close(self):
await self._socket.close()

@property
def address(self) -> Union[Tuple[str, int], str]:
Expand Down
12 changes: 6 additions & 6 deletions anyio/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,10 @@ async def __aenter__(self):
return self

async def __aexit__(self, *exc_info):
self.close()
await self.close()

@abstractmethod
def close(self) -> None:
async def close(self) -> None:
"""Close the underlying socket."""

@abstractmethod
Expand All @@ -282,10 +282,10 @@ async def __aenter__(self):
return self

async def __aexit__(self, *exc_info):
self.close()
await self.close()

@abstractmethod
def close(self) -> None:
async def close(self) -> None:
"""Close the underlying socket."""

@property
Expand Down Expand Up @@ -316,10 +316,10 @@ async def __aenter__(self):
return self

async def __aexit__(self, *exc_info):
self.close()
await self.close()

@abstractmethod
def close(self) -> None:
async def close(self) -> None:
"""Close the underlying socket."""

@property
Expand Down
4 changes: 4 additions & 0 deletions anyio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,7 @@ def __init__(self, data: bytes) -> None:
super().__init__(
'The delimiter was not found among the first {} bytes read'.format(len(data)))
self.data = data


class ClosedResourceError(Exception):
"""Raised when a resource is closed by another task."""
11 changes: 10 additions & 1 deletion tests/test_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from anyio import (
create_task_group, connect_tcp, create_udp_socket, connect_unix, create_unix_server,
create_tcp_server)
from anyio.exceptions import IncompleteRead, DelimiterNotFound
from anyio.exceptions import IncompleteRead, DelimiterNotFound, ClosedResourceError


@pytest.mark.anyio
Expand Down Expand Up @@ -163,3 +163,12 @@ async def test_udp_noconnect():
response, addr = await socket.receive(100)
assert response == b'halb'
assert addr == ('127.0.0.1', socket.port)


@pytest.mark.anyio
async def test_udp_close_socket_from_other_task():
async with create_task_group() as tg:
async with await create_udp_socket(interface='127.0.0.1') as udp:
await tg.spawn(udp.close)
with pytest.raises(ClosedResourceError):
await udp.receive(100)

0 comments on commit 8822ef9

Please sign in to comment.