diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index ce4c4db0..c9eb2b44 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 `_.
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
an object with a ``.fileno()`` method or an integer handle, and deprecated
their now obsolete versions (``wait_socket_readable()`` and
- ``wait_socket_writable()`` (PR by @davidbrochart)
+ ``wait_socket_writable()``) (PR by @davidbrochart)
+- Added support for ``wait_readable()`` and ``wait_writable()`` on ``ProactorEventLoop``
+ (used on asyncio + Windows by default)
- Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the
``anyio.AsyncFile`` class
(`#825 `_)
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 38b68f4d..c1fd0d1e 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -103,7 +103,9 @@
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
if TYPE_CHECKING:
- from _typeshed import HasFileno
+ from _typeshed import FileDescriptorLike
+else:
+ FileDescriptorLike = object
if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -2734,7 +2736,7 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)
@classmethod
- async def wait_readable(cls, obj: HasFileno | int) -> None:
+ async def wait_readable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
@@ -2746,25 +2748,30 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()
if read_events.get(obj):
- raise BusyResourceError("reading from") from None
+ raise BusyResourceError("reading from")
loop = get_running_loop()
- event = read_events[obj] = asyncio.Event()
- loop.add_reader(obj, event.set)
+ event = asyncio.Event()
+ try:
+ loop.add_reader(obj, event.set)
+ except NotImplementedError:
+ from anyio._core._asyncio_selector_thread import get_selector
+
+ selector = get_selector()
+ selector.add_reader(obj, event.set)
+ remove_reader = selector.remove_reader
+ else:
+ remove_reader = loop.remove_reader
+
+ read_events[obj] = event
try:
await event.wait()
finally:
- if read_events.pop(obj, None) is not None:
- loop.remove_reader(obj)
- readable = True
- else:
- readable = False
-
- if not readable:
- raise ClosedResourceError
+ remove_reader(obj)
+ del read_events[obj]
@classmethod
- async def wait_writable(cls, obj: HasFileno | int) -> None:
+ async def wait_writable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
@@ -2776,22 +2783,27 @@ async def wait_writable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()
if write_events.get(obj):
- raise BusyResourceError("writing to") from None
+ raise BusyResourceError("writing to")
loop = get_running_loop()
- event = write_events[obj] = asyncio.Event()
- loop.add_writer(obj, event.set)
+ event = asyncio.Event()
+ try:
+ loop.add_writer(obj, event.set)
+ except NotImplementedError:
+ from anyio._core._asyncio_selector_thread import get_selector
+
+ selector = get_selector()
+ selector.add_writer(obj, event.set)
+ remove_writer = selector.remove_writer
+ else:
+ remove_writer = loop.remove_writer
+
+ write_events[obj] = event
try:
await event.wait()
finally:
- if write_events.pop(obj, None) is not None:
- loop.remove_writer(obj)
- writable = True
- else:
- writable = False
-
- if not writable:
- raise ClosedResourceError
+ del write_events[obj]
+ remove_writer(obj)
@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
diff --git a/src/anyio/_core/_asyncio_selector_thread.py b/src/anyio/_core/_asyncio_selector_thread.py
new file mode 100644
index 00000000..d98c3040
--- /dev/null
+++ b/src/anyio/_core/_asyncio_selector_thread.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import asyncio
+import socket
+import threading
+from collections.abc import Callable
+from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from _typeshed import FileDescriptorLike
+
+_selector_lock = threading.Lock()
+_selector: Selector | None = None
+
+
+class Selector:
+ def __init__(self) -> None:
+ self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
+ self._selector = DefaultSelector()
+ self._send, self._receive = socket.socketpair()
+ self._send.setblocking(False)
+ self._receive.setblocking(False)
+ self._selector.register(self._receive, EVENT_READ)
+ self._closed = False
+
+ def start(self) -> None:
+ self._thread.start()
+ threading._register_atexit(self._stop) # type: ignore[attr-defined]
+
+ def _stop(self) -> None:
+ global _selector
+ self._closed = True
+ self._notify_self()
+ self._send.close()
+ self._thread.join()
+ self._selector.unregister(self._receive)
+ self._receive.close()
+ self._selector.close()
+ _selector = None
+ assert (
+ not self._selector.get_map()
+ ), "selector still has registered file descriptors after shutdown"
+
+ def _notify_self(self) -> None:
+ try:
+ self._send.send(b"\x00")
+ except BlockingIOError:
+ pass
+
+ def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
+ loop = asyncio.get_running_loop()
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
+ else:
+ if EVENT_READ in key.data:
+ raise ValueError(
+ "this file descriptor is already registered for reading"
+ )
+
+ key.data[EVENT_READ] = loop, callback
+ self._selector.modify(fd, key.events | EVENT_READ, key.data)
+
+ self._notify_self()
+
+ def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
+ loop = asyncio.get_running_loop()
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
+ else:
+ if EVENT_WRITE in key.data:
+ raise ValueError(
+ "this file descriptor is already registered for writing"
+ )
+
+ key.data[EVENT_WRITE] = loop, callback
+ self._selector.modify(fd, key.events | EVENT_WRITE, key.data)
+
+ self._notify_self()
+
+ def remove_reader(self, fd: FileDescriptorLike) -> bool:
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+
+ if new_events := key.events ^ EVENT_READ:
+ del key.data[EVENT_READ]
+ self._selector.modify(fd, new_events, key.data)
+ else:
+ self._selector.unregister(fd)
+
+ return True
+
+ def remove_writer(self, fd: FileDescriptorLike) -> bool:
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+
+ if new_events := key.events ^ EVENT_WRITE:
+ del key.data[EVENT_WRITE]
+ self._selector.modify(fd, new_events, key.data)
+ else:
+ self._selector.unregister(fd)
+
+ return True
+
+ def run(self) -> None:
+ while not self._closed:
+ for key, events in self._selector.select():
+ if key.fileobj is self._receive:
+ try:
+ while self._receive.recv(4096):
+ pass
+ except BlockingIOError:
+ pass
+
+ continue
+
+ if events & EVENT_READ:
+ loop, callback = key.data[EVENT_READ]
+ self.remove_reader(key.fd)
+ try:
+ loop.call_soon_threadsafe(callback)
+ except RuntimeError:
+ pass # the loop was already closed
+
+ if events & EVENT_WRITE:
+ loop, callback = key.data[EVENT_WRITE]
+ self.remove_writer(key.fd)
+ try:
+ loop.call_soon_threadsafe(callback)
+ except RuntimeError:
+ pass # the loop was already closed
+
+
+def get_selector() -> Selector:
+ global _selector
+
+ with _selector_lock:
+ if _selector is None:
+ _selector = Selector()
+ _selector.start()
+
+ return _selector
diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py
index 12da4c5c..a822d060 100644
--- a/src/anyio/_core/_sockets.py
+++ b/src/anyio/_core/_sockets.py
@@ -32,9 +32,9 @@
from ._tasks import create_task_group, move_on_after
if TYPE_CHECKING:
- from _typeshed import HasFileno
+ from _typeshed import FileDescriptorLike
else:
- HasFileno = object
+ FileDescriptorLike = object
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
@@ -609,9 +609,6 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
Wait until the given socket has data to be read.
- This does **NOT** work on Windows when using the asyncio backend with a proactor
- event loop (default on py3.8+).
-
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
@@ -649,7 +646,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
return get_async_backend().wait_writable(sock.fileno())
-def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
+def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object has data to be read.
@@ -663,10 +660,11 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
descriptors aren't supported, and neither are handles that refer to anything besides
a ``SOCKET``.
- This does **NOT** work on Windows when using the asyncio backend with a proactor
- event loop (default on py3.8+).
+ On backends where this functionality is not natively provided (asyncio
+ ``ProactorEventLoop`` on Windows), it is provided using a separate selector thread
+ which is set to shut down when the interpreter shuts down.
- .. warning:: Only use this on raw sockets that have not been wrapped by any higher
+ .. warning:: Don't use this on raw sockets that have been wrapped by any higher
level constructs like socket streams!
:param obj: an object with a ``.fileno()`` method or an integer handle
@@ -679,25 +677,22 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
return get_async_backend().wait_readable(obj)
-def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
+def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object can be written to.
- This does **NOT** work on Windows when using the asyncio backend with a proactor
- event loop (default on py3.8+).
-
- .. seealso:: See the documentation of :func:`wait_readable` for the definition of
- ``obj``.
-
- .. warning:: Only use this on raw sockets that have not been wrapped by any higher
- level constructs like socket streams!
-
:param obj: an object with a ``.fileno()`` method or an integer handle
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become writable
+ .. seealso:: See the documentation of :func:`wait_readable` for the definition of
+ ``obj`` and notes on backend compatibility.
+
+ .. warning:: Don't use this on raw sockets that have been wrapped by any higher
+ level constructs like socket streams!
+
"""
return get_async_backend().wait_writable(obj)
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index 8965ea61..b5143df0 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -65,7 +65,7 @@
from exceptiongroup import ExceptionGroup
if TYPE_CHECKING:
- from _typeshed import HasFileno
+ from _typeshed import FileDescriptorLike
AnyIPAddressFamily = Literal[
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
@@ -1858,16 +1858,7 @@ async def test_connect_tcp_getaddrinfo_context() -> None:
@pytest.mark.parametrize("socket_type", ["socket", "fd"])
@pytest.mark.parametrize("event", ["readable", "writable"])
-async def test_wait_socket(
- anyio_backend_name: str, event: str, socket_type: str
-) -> None:
- if anyio_backend_name == "asyncio" and platform.system() == "Windows":
- import asyncio
-
- policy = asyncio.get_event_loop_policy()
- if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
- pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
-
+async def test_wait_socket(event: str, socket_type: str) -> None:
wait = wait_readable if event == "readable" else wait_writable
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
@@ -1880,20 +1871,15 @@ async def test_wait_socket(
conn, addr = server_sock.accept()
with conn:
- sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
- with fail_after(10):
+ sock_or_fd: FileDescriptorLike = (
+ conn.fileno() if socket_type == "fd" else conn
+ )
+ with fail_after(3):
await wait(sock_or_fd)
assert conn.recv(1024) == b"Hello, world"
async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
- if anyio_backend_name == "asyncio" and platform.system() == "Windows":
- import asyncio
-
- policy = asyncio.get_event_loop_policy()
- if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
- pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
with pytest.warns(
DeprecationWarning,