diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index ce4c4db0..c9eb2b44 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 `_. - Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept an object with a ``.fileno()`` method or an integer handle, and deprecated their now obsolete versions (``wait_socket_readable()`` and - ``wait_socket_writable()`` (PR by @davidbrochart) + ``wait_socket_writable()``) (PR by @davidbrochart) +- Added support for ``wait_readable()`` and ``wait_writable()`` on ``ProactorEventLoop`` + (used on asyncio + Windows by default) - Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the ``anyio.AsyncFile`` class (`#825 `_) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 38b68f4d..c1fd0d1e 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -103,7 +103,9 @@ from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream if TYPE_CHECKING: - from _typeshed import HasFileno + from _typeshed import FileDescriptorLike +else: + FileDescriptorLike = object if sys.version_info >= (3, 10): from typing import ParamSpec @@ -2734,7 +2736,7 @@ async def getnameinfo( return await get_running_loop().getnameinfo(sockaddr, flags) @classmethod - async def wait_readable(cls, obj: HasFileno | int) -> None: + async def wait_readable(cls, obj: FileDescriptorLike) -> None: await cls.checkpoint() try: read_events = _read_events.get() @@ -2746,25 +2748,30 @@ async def wait_readable(cls, obj: HasFileno | int) -> None: obj = obj.fileno() if read_events.get(obj): - raise BusyResourceError("reading from") from None + raise BusyResourceError("reading from") loop = get_running_loop() - event = read_events[obj] = asyncio.Event() - loop.add_reader(obj, event.set) + event = asyncio.Event() + try: + loop.add_reader(obj, event.set) + except NotImplementedError: + from anyio._core._asyncio_selector_thread import get_selector + + selector = get_selector() + selector.add_reader(obj, event.set) + remove_reader = selector.remove_reader + else: + remove_reader = loop.remove_reader + + read_events[obj] = event try: await event.wait() finally: - if read_events.pop(obj, None) is not None: - loop.remove_reader(obj) - readable = True - else: - readable = False - - if not readable: - raise ClosedResourceError + remove_reader(obj) + del read_events[obj] @classmethod - async def wait_writable(cls, obj: HasFileno | int) -> None: + async def wait_writable(cls, obj: FileDescriptorLike) -> None: await cls.checkpoint() try: write_events = _write_events.get() @@ -2776,22 +2783,27 @@ async def wait_writable(cls, obj: HasFileno | int) -> None: obj = obj.fileno() if write_events.get(obj): - raise BusyResourceError("writing to") from None + raise BusyResourceError("writing to") loop = get_running_loop() - event = write_events[obj] = asyncio.Event() - loop.add_writer(obj, event.set) + event = asyncio.Event() + try: + loop.add_writer(obj, event.set) + except NotImplementedError: + from anyio._core._asyncio_selector_thread import get_selector + + selector = get_selector() + selector.add_writer(obj, event.set) + remove_writer = selector.remove_writer + else: + remove_writer = loop.remove_writer + + write_events[obj] = event try: await event.wait() finally: - if write_events.pop(obj, None) is not None: - loop.remove_writer(obj) - writable = True - else: - writable = False - - if not writable: - raise ClosedResourceError + del write_events[obj] + remove_writer(obj) @classmethod def current_default_thread_limiter(cls) -> CapacityLimiter: diff --git a/src/anyio/_core/_asyncio_selector_thread.py b/src/anyio/_core/_asyncio_selector_thread.py new file mode 100644 index 00000000..d98c3040 --- /dev/null +++ b/src/anyio/_core/_asyncio_selector_thread.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import asyncio +import socket +import threading +from collections.abc import Callable +from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from _typeshed import FileDescriptorLike + +_selector_lock = threading.Lock() +_selector: Selector | None = None + + +class Selector: + def __init__(self) -> None: + self._thread = threading.Thread(target=self.run, name="AnyIO socket selector") + self._selector = DefaultSelector() + self._send, self._receive = socket.socketpair() + self._send.setblocking(False) + self._receive.setblocking(False) + self._selector.register(self._receive, EVENT_READ) + self._closed = False + + def start(self) -> None: + self._thread.start() + threading._register_atexit(self._stop) # type: ignore[attr-defined] + + def _stop(self) -> None: + global _selector + self._closed = True + self._notify_self() + self._send.close() + self._thread.join() + self._selector.unregister(self._receive) + self._receive.close() + self._selector.close() + _selector = None + assert ( + not self._selector.get_map() + ), "selector still has registered file descriptors after shutdown" + + def _notify_self(self) -> None: + try: + self._send.send(b"\x00") + except BlockingIOError: + pass + + def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None: + loop = asyncio.get_running_loop() + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)}) + else: + if EVENT_READ in key.data: + raise ValueError( + "this file descriptor is already registered for reading" + ) + + key.data[EVENT_READ] = loop, callback + self._selector.modify(fd, key.events | EVENT_READ, key.data) + + self._notify_self() + + def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None: + loop = asyncio.get_running_loop() + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)}) + else: + if EVENT_WRITE in key.data: + raise ValueError( + "this file descriptor is already registered for writing" + ) + + key.data[EVENT_WRITE] = loop, callback + self._selector.modify(fd, key.events | EVENT_WRITE, key.data) + + self._notify_self() + + def remove_reader(self, fd: FileDescriptorLike) -> bool: + try: + key = self._selector.get_key(fd) + except KeyError: + return False + + if new_events := key.events ^ EVENT_READ: + del key.data[EVENT_READ] + self._selector.modify(fd, new_events, key.data) + else: + self._selector.unregister(fd) + + return True + + def remove_writer(self, fd: FileDescriptorLike) -> bool: + try: + key = self._selector.get_key(fd) + except KeyError: + return False + + if new_events := key.events ^ EVENT_WRITE: + del key.data[EVENT_WRITE] + self._selector.modify(fd, new_events, key.data) + else: + self._selector.unregister(fd) + + return True + + def run(self) -> None: + while not self._closed: + for key, events in self._selector.select(): + if key.fileobj is self._receive: + try: + while self._receive.recv(4096): + pass + except BlockingIOError: + pass + + continue + + if events & EVENT_READ: + loop, callback = key.data[EVENT_READ] + self.remove_reader(key.fd) + try: + loop.call_soon_threadsafe(callback) + except RuntimeError: + pass # the loop was already closed + + if events & EVENT_WRITE: + loop, callback = key.data[EVENT_WRITE] + self.remove_writer(key.fd) + try: + loop.call_soon_threadsafe(callback) + except RuntimeError: + pass # the loop was already closed + + +def get_selector() -> Selector: + global _selector + + with _selector_lock: + if _selector is None: + _selector = Selector() + _selector.start() + + return _selector diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 12da4c5c..a822d060 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -32,9 +32,9 @@ from ._tasks import create_task_group, move_on_after if TYPE_CHECKING: - from _typeshed import HasFileno + from _typeshed import FileDescriptorLike else: - HasFileno = object + FileDescriptorLike = object if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -609,9 +609,6 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: Wait until the given socket has data to be read. - This does **NOT** work on Windows when using the asyncio backend with a proactor - event loop (default on py3.8+). - .. warning:: Only use this on raw sockets that have not been wrapped by any higher level constructs like socket streams! @@ -649,7 +646,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: return get_async_backend().wait_writable(sock.fileno()) -def wait_readable(obj: HasFileno | int) -> Awaitable[None]: +def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]: """ Wait until the given object has data to be read. @@ -663,10 +660,11 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]: descriptors aren't supported, and neither are handles that refer to anything besides a ``SOCKET``. - This does **NOT** work on Windows when using the asyncio backend with a proactor - event loop (default on py3.8+). + On backends where this functionality is not natively provided (asyncio + ``ProactorEventLoop`` on Windows), it is provided using a separate selector thread + which is set to shut down when the interpreter shuts down. - .. warning:: Only use this on raw sockets that have not been wrapped by any higher + .. warning:: Don't use this on raw sockets that have been wrapped by any higher level constructs like socket streams! :param obj: an object with a ``.fileno()`` method or an integer handle @@ -679,25 +677,22 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]: return get_async_backend().wait_readable(obj) -def wait_writable(obj: HasFileno | int) -> Awaitable[None]: +def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]: """ Wait until the given object can be written to. - This does **NOT** work on Windows when using the asyncio backend with a proactor - event loop (default on py3.8+). - - .. seealso:: See the documentation of :func:`wait_readable` for the definition of - ``obj``. - - .. warning:: Only use this on raw sockets that have not been wrapped by any higher - level constructs like socket streams! - :param obj: an object with a ``.fileno()`` method or an integer handle :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the object to become writable :raises ~anyio.BusyResourceError: if another task is already waiting for the object to become writable + .. seealso:: See the documentation of :func:`wait_readable` for the definition of + ``obj`` and notes on backend compatibility. + + .. warning:: Don't use this on raw sockets that have been wrapped by any higher + level constructs like socket streams! + """ return get_async_backend().wait_writable(obj) diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 8965ea61..b5143df0 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -65,7 +65,7 @@ from exceptiongroup import ExceptionGroup if TYPE_CHECKING: - from _typeshed import HasFileno + from _typeshed import FileDescriptorLike AnyIPAddressFamily = Literal[ AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6 @@ -1858,16 +1858,7 @@ async def test_connect_tcp_getaddrinfo_context() -> None: @pytest.mark.parametrize("socket_type", ["socket", "fd"]) @pytest.mark.parametrize("event", ["readable", "writable"]) -async def test_wait_socket( - anyio_backend_name: str, event: str, socket_type: str -) -> None: - if anyio_backend_name == "asyncio" and platform.system() == "Windows": - import asyncio - - policy = asyncio.get_event_loop_policy() - if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy": - pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop") - +async def test_wait_socket(event: str, socket_type: str) -> None: wait = wait_readable if event == "readable" else wait_writable with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock: @@ -1880,20 +1871,15 @@ async def test_wait_socket( conn, addr = server_sock.accept() with conn: - sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn - with fail_after(10): + sock_or_fd: FileDescriptorLike = ( + conn.fileno() if socket_type == "fd" else conn + ) + with fail_after(3): await wait(sock_or_fd) assert conn.recv(1024) == b"Hello, world" async def test_deprecated_wait_socket(anyio_backend_name: str) -> None: - if anyio_backend_name == "asyncio" and platform.system() == "Windows": - import asyncio - - policy = asyncio.get_event_loop_policy() - if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy": - pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop") - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with pytest.warns( DeprecationWarning,