diff --git a/pyproject.toml b/pyproject.toml index 79beab840d..d123d4f158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,10 +60,6 @@ disallow_untyped_calls = false # files not yet fully typed [[tool.mypy.overrides]] module = [ -# 2761 -"trio/_core/_generated_io_windows", -"trio/_core/_io_windows", - # internal "trio/_windows_pipes", diff --git a/test-requirements.in b/test-requirements.in index 7461a1e4ed..86e733657f 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -22,6 +22,7 @@ codespell # https://github.com/python-trio/trio/pull/654#issuecomment-420518745 mypy-extensions; implementation_name == "cpython" typing-extensions +types-cffi; implementation_name == "cpython" # Trio's own dependencies cffi; os_name == "nt" diff --git a/test-requirements.txt b/test-requirements.txt index 2c86c7439c..150aa20174 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -124,8 +124,12 @@ tomlkit==0.12.1 # via pylint trustme==1.1.0 # via -r test-requirements.in +types-cffi==1.15.1.15 ; implementation_name == "cpython" + # via -r test-requirements.in types-pyopenssl==23.2.0.2 ; implementation_name == "cpython" # via -r test-requirements.in +types-setuptools==68.1.0.0 + # via types-cffi typing-extensions==4.7.1 # via # -r test-requirements.in diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index b81255d8a9..ca444373fa 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -3,16 +3,24 @@ # ************************************************************* from __future__ import annotations -import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ContextManager from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from .._file_io import _HasFileNo + from ._windows_cffi import Handle, CData + from typing_extensions import Buffer + + from ._unbounded_queue import UnboundedQueue + +import sys + assert not TYPE_CHECKING or sys.platform == "win32" -async def wait_readable(sock): +async def wait_readable(sock: (_HasFileNo | int)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -20,7 +28,7 @@ async def wait_readable(sock): raise RuntimeError("must be called from async context") -async def wait_writable(sock): +async def wait_writable(sock: (_HasFileNo | int)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -28,7 +36,7 @@ async def wait_writable(sock): raise RuntimeError("must be called from async context") -def notify_closing(handle): +def notify_closing(handle: (Handle | int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -36,7 +44,7 @@ def notify_closing(handle): raise RuntimeError("must be called from async context") -def register_with_iocp(handle): +def register_with_iocp(handle: (int | CData)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -44,17 +52,21 @@ def register_with_iocp(handle): raise RuntimeError("must be called from async context") -async def wait_overlapped(handle, lpOverlapped): +async def wait_overlapped( + handle_: (int | CData), lpOverlapped: (CData | int) +) -> object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await 
GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( - handle, lpOverlapped + handle_, lpOverlapped ) except AttributeError: raise RuntimeError("must be called from async context") -async def write_overlapped(handle, data, file_offset=0): +async def write_overlapped( + handle: (int | CData), data: Buffer, file_offset: int = 0 +) -> int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( @@ -64,7 +76,9 @@ async def write_overlapped(handle, data, file_offset=0): raise RuntimeError("must be called from async context") -async def readinto_overlapped(handle, buffer, file_offset=0): +async def readinto_overlapped( + handle: (int | CData), buffer: Buffer, file_offset: int = 0 +) -> int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( @@ -74,7 +88,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0): raise RuntimeError("must be called from async context") -def current_iocp(): +def current_iocp() -> int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() @@ -82,7 +96,7 @@ def current_iocp(): raise RuntimeError("must be called from async context") -def monitor_completion_key(): +def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py index c1af293278..14cd9d33e6 100644 --- a/trio/_core/_io_common.py +++ b/trio/_core/_io_common.py @@ -9,10 +9,11 @@ if TYPE_CHECKING: from ._io_epoll import EpollWaiters + from ._io_windows import AFDWaiters # Utility function shared between _io_epoll and _io_windows -def wake_all(waiters: EpollWaiters, exc: BaseException) -> None: +def wake_all(waiters: EpollWaiters | AFDWaiters, exc: BaseException) -> None: try: current_task = _core.current_task() except RuntimeError: diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 9757d25b5f..ba84525506 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -5,7 +5,16 @@ import socket import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Literal +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Literal, + Optional, + TypeVar, + cast, +) import attr from outcome import Value @@ -16,12 +25,15 @@ from ._windows_cffi import ( INVALID_HANDLE_VALUE, AFDPollFlags, + CData, CompletionModes, ErrorCodes, FileFlags, + Handle, IoControlCodes, WSAIoctls, _handle, + _Overlapped, ffi, kernel32, ntdll, @@ -29,11 +41,15 @@ ws2_32, ) -assert not TYPE_CHECKING or sys.platform == "win32" - if TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing_extensions import Buffer, TypeAlias + + from .._file_io import _HasFileNo + from ._traps import Abort, RaiseCancelT + from ._unbounded_queue import UnboundedQueue + EventResult: TypeAlias = int +T = TypeVar("T") # There's a lot to be said about the overall design of a Windows event # loop. See @@ -185,13 +201,111 @@ class CKeys(enum.IntEnum): USER_DEFINED = 4 # and above -def _check(success): +# AFD_POLL has a finer-grained set of events than other APIs. We collapse them +# down into Unix-style "readable" and "writable". 
+# +# Note: AFD_POLL_LOCAL_CLOSE isn't a reliable substitute for notify_closing(), +# because even if the user closes the socket *handle*, the socket *object* +# could still remain open, e.g. if the socket was dup'ed (possibly into +# another process). Explicitly calling notify_closing() guarantees that +# everyone waiting on the *handle* wakes up, which is what you'd expect. +# +# However, we can't avoid getting LOCAL_CLOSE notifications -- the kernel +# delivers them whether we ask for them or not -- so better to include them +# here for documentation, and so that when we check (delivered & requested) we +# get a match. + +READABLE_FLAGS = ( + AFDPollFlags.AFD_POLL_RECEIVE + | AFDPollFlags.AFD_POLL_ACCEPT + | AFDPollFlags.AFD_POLL_DISCONNECT # other side sent an EOF + | AFDPollFlags.AFD_POLL_ABORT + | AFDPollFlags.AFD_POLL_LOCAL_CLOSE +) + +WRITABLE_FLAGS = ( + AFDPollFlags.AFD_POLL_SEND + | AFDPollFlags.AFD_POLL_CONNECT_FAIL + | AFDPollFlags.AFD_POLL_ABORT + | AFDPollFlags.AFD_POLL_LOCAL_CLOSE +) + + +# Annoyingly, while the API makes it *seem* like you can happily issue as many +# independent AFD_POLL operations as you want without them interfering with +# each other, in fact if you issue two AFD_POLL operations for the same socket +# at the same time with notification going to the same IOCP port, then Windows +# gets super confused. For example, if we issue one operation from +# wait_readable, and another independent operation from wait_writable, then +# Windows may complete the wait_writable operation when the socket becomes +# readable. +# +# To avoid this, we have to coalesce all the operations on a single socket +# into one, and when the set of waiters changes we have to throw away the old +# operation and start a new one. +@attr.s(slots=True, eq=False) +class AFDWaiters: + read_task: Optional[_core.Task] = attr.ib(default=None) + write_task: Optional[_core.Task] = attr.ib(default=None) + current_op: Optional[AFDPollOp] = attr.ib(default=None) + + +# We also need to bundle up all the info for a single op into a standalone +# object, because we need to keep all these objects alive until the operation +# finishes, even if we're throwing it away. +@attr.s(slots=True, eq=False, frozen=True) +class AFDPollOp: + lpOverlapped: CData = attr.ib() + poll_info: Any = attr.ib() + waiters: AFDWaiters = attr.ib() + afd_group: AFDGroup = attr.ib() + + +# The Windows kernel has a weird issue when using AFD handles. If you have N +# instances of wait_readable/wait_writable registered with a single AFD handle, +# then cancelling any one of them takes something like O(N**2) time. So if we +# used just a single AFD handle, then cancellation would quickly become very +# expensive, e.g. a program with N active sockets would take something like +# O(N**3) time to unwind after control-C. The solution is to spread our sockets +# out over multiple AFD handles, so that N doesn't grow too large for any +# individual handle. 
+MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite + + +@attr.s(slots=True, eq=False) +class AFDGroup: + size: int = attr.ib() + handle: Handle = attr.ib() + + +assert not TYPE_CHECKING or sys.platform == "win32" + + +@attr.s(slots=True, eq=False, frozen=True) +class _WindowsStatistics: + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + tasks_waiting_overlapped: int = attr.ib() + completion_key_monitors: int = attr.ib() + backend: Literal["windows"] = attr.ib(init=False, default="windows") + + +# Maximum number of events to dequeue from the completion port on each pass +# through the run loop. Somewhat arbitrary. Should be large enough to collect +# a good set of tasks on each loop, but not so large to waste tons of memory. +# (Each WindowsIOManager holds a buffer whose size is ~32x this number.) +MAX_EVENTS = 1000 + + +def _check(success: T) -> T: if not success: raise_winerror() return success -def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): +def _get_underlying_socket( + sock: _HasFileNo | int | Handle, *, which: WSAIoctls = WSAIoctls.SIO_BASE_HANDLE +) -> Handle: if hasattr(sock, "fileno"): sock = sock.fileno() base_ptr = ffi.new("HANDLE *") @@ -210,10 +324,10 @@ def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): if failed: code = ws2_32.WSAGetLastError() raise_winerror(code) - return base_ptr[0] + return Handle(base_ptr[0]) -def _get_base_socket(sock): +def _get_base_socket(sock: _HasFileNo | int | Handle) -> Handle: # There is a development kit for LSPs called Komodia Redirector. # It does some unusual (some might say evil) things like intercepting # SIO_BASE_HANDLE (fails) and SIO_BSP_HANDLE_SELECT (returns the same @@ -260,7 +374,7 @@ def _get_base_socket(sock): sock = next_sock -def _afd_helper_handle(): +def _afd_helper_handle() -> Handle: # The "AFD" driver is exposed at the NT path "\Device\Afd". We're using # the Win32 CreateFile, though, so we have to pass a Win32 path. \\.\ is # how Win32 refers to the NT \GLOBAL??\ directory, and GLOBALROOT is a @@ -292,130 +406,37 @@ def _afd_helper_handle(): return handle -# AFD_POLL has a finer-grained set of events than other APIs. We collapse them -# down into Unix-style "readable" and "writable". -# -# Note: AFD_POLL_LOCAL_CLOSE isn't a reliable substitute for notify_closing(), -# because even if the user closes the socket *handle*, the socket *object* -# could still remain open, e.g. if the socket was dup'ed (possibly into -# another process). Explicitly calling notify_closing() guarantees that -# everyone waiting on the *handle* wakes up, which is what you'd expect. -# -# However, we can't avoid getting LOCAL_CLOSE notifications -- the kernel -# delivers them whether we ask for them or not -- so better to include them -# here for documentation, and so that when we check (delivered & requested) we -# get a match. 
- -READABLE_FLAGS = ( - AFDPollFlags.AFD_POLL_RECEIVE - | AFDPollFlags.AFD_POLL_ACCEPT - | AFDPollFlags.AFD_POLL_DISCONNECT # other side sent an EOF - | AFDPollFlags.AFD_POLL_ABORT - | AFDPollFlags.AFD_POLL_LOCAL_CLOSE -) - -WRITABLE_FLAGS = ( - AFDPollFlags.AFD_POLL_SEND - | AFDPollFlags.AFD_POLL_CONNECT_FAIL - | AFDPollFlags.AFD_POLL_ABORT - | AFDPollFlags.AFD_POLL_LOCAL_CLOSE -) - - -# Annoyingly, while the API makes it *seem* like you can happily issue as many -# independent AFD_POLL operations as you want without them interfering with -# each other, in fact if you issue two AFD_POLL operations for the same socket -# at the same time with notification going to the same IOCP port, then Windows -# gets super confused. For example, if we issue one operation from -# wait_readable, and another independent operation from wait_writable, then -# Windows may complete the wait_writable operation when the socket becomes -# readable. -# -# To avoid this, we have to coalesce all the operations on a single socket -# into one, and when the set of waiters changes we have to throw away the old -# operation and start a new one. -@attr.s(slots=True, eq=False) -class AFDWaiters: - read_task = attr.ib(default=None) - write_task = attr.ib(default=None) - current_op = attr.ib(default=None) - - -# We also need to bundle up all the info for a single op into a standalone -# object, because we need to keep all these objects alive until the operation -# finishes, even if we're throwing it away. -@attr.s(slots=True, eq=False, frozen=True) -class AFDPollOp: - lpOverlapped = attr.ib() - poll_info = attr.ib() - waiters = attr.ib() - afd_group = attr.ib() - - -# The Windows kernel has a weird issue when using AFD handles. If you have N -# instances of wait_readable/wait_writable registered with a single AFD handle, -# then cancelling any one of them takes something like O(N**2) time. So if we -# used just a single AFD handle, then cancellation would quickly become very -# expensive, e.g. a program with N active sockets would take something like -# O(N**3) time to unwind after control-C. The solution is to spread our sockets -# out over multiple AFD handles, so that N doesn't grow too large for any -# individual handle. -MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite - - -@attr.s(slots=True, eq=False) -class AFDGroup: - size = attr.ib() - handle = attr.ib() - - -@attr.s(slots=True, eq=False, frozen=True) -class _WindowsStatistics: - tasks_waiting_read: int = attr.ib() - tasks_waiting_write: int = attr.ib() - tasks_waiting_overlapped: int = attr.ib() - completion_key_monitors: int = attr.ib() - backend: Literal["windows"] = attr.ib(init=False, default="windows") - - -# Maximum number of events to dequeue from the completion port on each pass -# through the run loop. Somewhat arbitrary. Should be large enough to collect -# a good set of tasks on each loop, but not so large to waste tons of memory. -# (Each WindowsIOManager holds a buffer whose size is ~32x this number.) -MAX_EVENTS = 1000 - - @attr.s(frozen=True) class CompletionKeyEventInfo: - lpOverlapped = attr.ib() - dwNumberOfBytesTransferred = attr.ib() + lpOverlapped: CData = attr.ib() + dwNumberOfBytesTransferred: int = attr.ib() class WindowsIOManager: - def __init__(self): + def __init__(self) -> None: # If this method raises an exception, then __del__ could run on a # half-initialized object. So we initialize everything that __del__ # touches to safe values up front, before we do anything that can # fail. 
self._iocp = None - self._all_afd_handles = [] + self._all_afd_handles: list[Handle] = [] self._iocp = _check( kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0) ) self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS) - self._vacant_afd_groups = set() + self._vacant_afd_groups: set[AFDGroup] = set() # {lpOverlapped: AFDPollOp} - self._afd_ops = {} + self._afd_ops: dict[CData, AFDPollOp] = {} # {socket handle: AFDWaiters} - self._afd_waiters = {} + self._afd_waiters: dict[Handle, AFDWaiters] = {} # {lpOverlapped: task} - self._overlapped_waiters = {} - self._posted_too_late_to_cancel = set() + self._overlapped_waiters: dict[CData, _core.Task] = {} + self._posted_too_late_to_cancel: set[CData] = set() - self._completion_key_queues = {} + self._completion_key_queues: dict[int, UnboundedQueue[object]] = {} self._completion_key_counter = itertools.count(CKeys.USER_DEFINED) with socket.socket() as s: @@ -455,7 +476,7 @@ def __init__(self): "netsh winsock show catalog" ) - def close(self): + def close(self) -> None: try: if self._iocp is not None: iocp = self._iocp @@ -466,10 +487,10 @@ def close(self): afd_handle = self._all_afd_handles.pop() _check(kernel32.CloseHandle(afd_handle)) - def __del__(self): + def __del__(self) -> None: self.close() - def statistics(self): + def statistics(self) -> _WindowsStatistics: tasks_waiting_read = 0 tasks_waiting_write = 0 for waiter in self._afd_waiters.values(): @@ -484,7 +505,8 @@ def statistics(self): completion_key_monitors=len(self._completion_key_queues), ) - def force_wakeup(self): + def force_wakeup(self) -> None: + assert self._iocp is not None _check( kernel32.PostQueuedCompletionStatus( self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL @@ -497,6 +519,7 @@ def get_events(self, timeout: float) -> EventResult: if timeout > 0 and milliseconds == 0: milliseconds = 1 try: + assert self._iocp is not None _check( kernel32.GetQueuedCompletionStatusEx( self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0 @@ -590,8 +613,9 @@ def process_events(self, received: EventResult) -> None: ) queue.put_nowait(info) - def _register_with_iocp(self, handle, completion_key): - handle = _handle(handle) + def _register_with_iocp(self, handle_: int | CData, completion_key: int) -> None: + handle = _handle(handle_) + assert self._iocp is not None _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) # Supposedly this makes things slightly faster, by disabling the # ability to do WaitForSingleObject(handle). 
We would never want to do @@ -607,7 +631,7 @@ def _register_with_iocp(self, handle, completion_key): # AFD stuff ################################################################ - def _refresh_afd(self, base_handle): + def _refresh_afd(self, base_handle: Handle) -> None: waiters = self._afd_waiters[base_handle] if waiters.current_op is not None: afd_group = waiters.current_op.afd_group @@ -645,7 +669,7 @@ def _refresh_afd(self, base_handle): lpOverlapped = ffi.new("LPOVERLAPPED") - poll_info = ffi.new("AFD_POLL_INFO *") + poll_info: Any = ffi.new("AFD_POLL_INFO *") poll_info.Timeout = 2**63 - 1 # INT64_MAX poll_info.NumberOfHandles = 1 poll_info.Exclusive = 0 @@ -683,7 +707,7 @@ def _refresh_afd(self, base_handle): if afd_group.size >= MAX_AFD_GROUP_SIZE: self._vacant_afd_groups.remove(afd_group) - async def _afd_poll(self, sock, mode): + async def _afd_poll(self, sock: _HasFileNo | int, mode: str) -> None: base_handle = _get_base_socket(sock) waiters = self._afd_waiters.get(base_handle) if waiters is None: @@ -696,7 +720,7 @@ async def _afd_poll(self, sock, mode): # we let it escape. self._refresh_afd(base_handle) - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: setattr(waiters, mode, None) self._refresh_afd(base_handle) return _core.Abort.SUCCEEDED @@ -704,15 +728,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock): + async def wait_readable(self, sock: _HasFileNo | int) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock): + async def wait_writable(self, sock: _HasFileNo | int) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle): + def notify_closing(self, handle: Handle | int | _HasFileNo) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -724,12 +748,14 @@ def notify_closing(self, handle): ################################################################ @_public - def register_with_iocp(self, handle): + def register_with_iocp(self, handle: int | CData) -> None: self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) @_public - async def wait_overlapped(self, handle, lpOverlapped): - handle = _handle(handle) + async def wait_overlapped( + self, handle_: int | CData, lpOverlapped: CData | int + ) -> object: + handle = _handle(handle_) if isinstance(lpOverlapped, int): lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) if lpOverlapped in self._overlapped_waiters: @@ -740,13 +766,14 @@ async def wait_overlapped(self, handle, lpOverlapped): self._overlapped_waiters[lpOverlapped] = task raise_cancel = None - def abort(raise_cancel_): + def abort(raise_cancel_: RaiseCancelT) -> Abort: nonlocal raise_cancel raise_cancel = raise_cancel_ try: _check(kernel32.CancelIoEx(handle, lpOverlapped)) except OSError as exc: if exc.winerror == ErrorCodes.ERROR_NOT_FOUND: + assert self._iocp is not None # Too late to cancel. If this happens because the # operation is already completed, we don't need to do # anything; we'll get a notification of that completion @@ -775,12 +802,14 @@ def abort(raise_cancel_): ) from exc return _core.Abort.FAILED + # TODO: what type does this return? 
info = await _core.wait_task_rescheduled(abort) - if lpOverlapped.Internal != 0: + lpOverlappedTyped = cast("_Overlapped", lpOverlapped) + if lpOverlappedTyped.Internal != 0: # the lpOverlapped reports the error as an NT status code, # which we must convert back to a Win32 error code before # it will produce the right sorts of exceptions - code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal) + code = ntdll.RtlNtStatusToDosError(lpOverlappedTyped.Internal) if code == ErrorCodes.ERROR_OPERATION_ABORTED: if raise_cancel is not None: raise_cancel() @@ -793,7 +822,9 @@ def abort(raise_cancel_): raise_winerror(code) return info - async def _perform_overlapped(self, handle, submit_fn): + async def _perform_overlapped( + self, handle: int | CData, submit_fn: Callable[[_Overlapped], None] + ) -> _Overlapped: # submit_fn(lpOverlapped) submits some I/O # it may raise an OSError with ERROR_IO_PENDING # the handle must already be registered using @@ -802,20 +833,22 @@ async def _perform_overlapped(self, handle, submit_fn): # operation will not be cancellable, depending on how Windows is # feeling today. So we need to check for cancellation manually. await _core.checkpoint_if_cancelled() - lpOverlapped = ffi.new("LPOVERLAPPED") + lpOverlapped = cast(_Overlapped, ffi.new("LPOVERLAPPED")) try: submit_fn(lpOverlapped) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_IO_PENDING: raise - await self.wait_overlapped(handle, lpOverlapped) + await self.wait_overlapped(handle, cast(CData, lpOverlapped)) return lpOverlapped @_public - async def write_overlapped(self, handle, data, file_offset=0): + async def write_overlapped( + self, handle: int | CData, data: Buffer, file_offset: int = 0 + ) -> int: with ffi.from_buffer(data) as cbuf: - def submit_write(lpOverlapped): + def submit_write(lpOverlapped: _Overlapped) -> None: # yes, these are the real documented names offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME offset_fields.Offset = file_offset & 0xFFFFFFFF @@ -835,10 +868,12 @@ def submit_write(lpOverlapped): return lpOverlapped.InternalHigh @_public - async def readinto_overlapped(self, handle, buffer, file_offset=0): + async def readinto_overlapped( + self, handle: int | CData, buffer: Buffer, file_offset: int = 0 + ) -> int: with ffi.from_buffer(buffer, require_writable=True) as cbuf: - def submit_read(lpOverlapped): + def submit_read(lpOverlapped: _Overlapped) -> None: offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME offset_fields.Offset = file_offset & 0xFFFFFFFF offset_fields.OffsetHigh = file_offset >> 32 @@ -860,14 +895,15 @@ def submit_read(lpOverlapped): ################################################################ @_public - def current_iocp(self): + def current_iocp(self) -> int: + assert self._iocp is not None return int(ffi.cast("uintptr_t", self._iocp)) @contextmanager @_public - def monitor_completion_key(self): + def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: key = next(self._completion_key_counter) - queue = _core.UnboundedQueue() + queue = _core.UnboundedQueue[object]() self._completion_key_queues[key] = queue try: yield (key, queue) diff --git a/trio/_core/_tests/test_windows.py b/trio/_core/_tests/test_windows.py index 7beb59cc21..f0961cf9e4 100644 --- a/trio/_core/_tests/test_windows.py +++ b/trio/_core/_tests/test_windows.py @@ -15,7 +15,9 @@ # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") -assert sys.platform == "win32" or not 
TYPE_CHECKING # Skip type checking on Windows +assert ( + sys.platform == "win32" or not TYPE_CHECKING +) # Skip type checking when not on Windows from ... import _core, sleep from ...testing import wait_all_tasks_blocked @@ -25,6 +27,7 @@ from .._windows_cffi import ( INVALID_HANDLE_VALUE, FileFlags, + Handle, ffi, kernel32, raise_winerror, @@ -73,8 +76,10 @@ def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: # then we filter out the warning. @pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") async def test_completion_key_listen() -> None: + from .. import _io_windows + async def post(key: int) -> None: - iocp = ffi.cast("HANDLE", _core.current_iocp()) + iocp = Handle(ffi.cast("HANDLE", _core.current_iocp())) for i in range(10): print("post", i) if i % 3 == 0: @@ -90,6 +95,7 @@ async def post(key: int) -> None: async for batch in queue: # pragma: no branch print("got some", batch) for info in batch: + assert isinstance(info, _io_windows.CompletionKeyEventInfo) assert info.lpOverlapped == 0 assert info.dwNumberOfBytesTransferred == i i += 1 @@ -153,8 +159,8 @@ def pipe_with_overlapped_read() -> Generator[tuple[BufferedWriter, int], None, N write_fd = msvcrt.open_osfhandle(write_handle, 0) yield os.fdopen(write_fd, "wb", closefd=False), read_handle finally: - kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) - kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) + kernel32.CloseHandle(Handle(ffi.cast("HANDLE", read_handle))) + kernel32.CloseHandle(Handle(ffi.cast("HANDLE", write_handle))) @restore_unraisablehook() diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index a65a332c2f..d411770971 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -2,7 +2,7 @@ import enum import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NewType, Protocol, cast if TYPE_CHECKING: from typing_extensions import NoReturn, TypeAlias @@ -222,12 +222,196 @@ LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB) ffi = cffi.api.FFI() -CData: TypeAlias = cffi.api.FFI.CData ffi.cdef(LIB) -kernel32 = ffi.dlopen("kernel32.dll") -ntdll = ffi.dlopen("ntdll.dll") -ws2_32 = ffi.dlopen("ws2_32.dll") +CData: TypeAlias = cffi.api.FFI.CData +CType: TypeAlias = cffi.api.FFI.CType +AlwaysNull: TypeAlias = CType # We currently always pass ffi.NULL here. +Handle = NewType("Handle", CData) +HandleArray = NewType("HandleArray", CData) + + +class _Kernel32(Protocol): + """Statically typed version of the kernel32.dll functions we use.""" + + def CreateIoCompletionPort( + self, + FileHandle: Handle, + ExistingCompletionPort: CData | AlwaysNull, + CompletionKey: int, + NumberOfConcurrentThreads: int, + /, + ) -> Handle: + ... + + def CreateEventA( + self, + lpEventAttributes: AlwaysNull, + bManualReset: bool, + bInitialState: bool, + lpName: AlwaysNull, + /, + ) -> Handle: + ... + + def SetFileCompletionNotificationModes( + self, handle: Handle, flags: CompletionModes, / + ) -> int: + ... + + def PostQueuedCompletionStatus( + self, + CompletionPort: Handle, + dwNumberOfBytesTransferred: int, + dwCompletionKey: int, + lpOverlapped: CData | AlwaysNull, + /, + ) -> bool: + ... + + def CancelIoEx( + self, + hFile: Handle, + lpOverlapped: CData | AlwaysNull, + /, + ) -> bool: + ... + + def WriteFile( + self, + hFile: Handle, + # not sure about this type + lpBuffer: CData, + nNumberOfBytesToWrite: int, + lpNumberOfBytesWritten: AlwaysNull, + lpOverlapped: _Overlapped, + /, + ) -> bool: + ... 
+ + def ReadFile( + self, + hFile: Handle, + # not sure about this type + lpBuffer: CData, + nNumberOfBytesToRead: int, + lpNumberOfBytesRead: AlwaysNull, + lpOverlapped: _Overlapped, + /, + ) -> bool: + ... + + def GetQueuedCompletionStatusEx( + self, + CompletionPort: Handle, + lpCompletionPortEntries: CData, + ulCount: int, + ulNumEntriesRemoved: CData, + dwMilliseconds: int, + fAlertable: bool | int, + /, + ) -> CData: + ... + + def CreateFileW( + self, + lpFileName: CData, + dwDesiredAccess: FileFlags, + dwShareMode: FileFlags, + lpSecurityAttributes: AlwaysNull, + dwCreationDisposition: FileFlags, + dwFlagsAndAttributes: FileFlags, + hTemplateFile: AlwaysNull, + /, + ) -> Handle: + ... + + def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: + ... + + def WaitForMultipleObjects( + self, + nCount: int, + lpHandles: HandleArray, + bWaitAll: bool, + dwMilliseconds: int, + /, + ) -> ErrorCodes: + ... + + def SetEvent(self, handle: Handle, /) -> None: + ... + + def CloseHandle(self, handle: Handle, /) -> bool: + ... + + def DeviceIoControl( + self, + hDevice: Handle, + dwIoControlCode: int, + # this is wrong (it's not always null) + lpInBuffer: AlwaysNull, + nInBufferSize: int, + # this is also wrong + lpOutBuffer: AlwaysNull, + nOutBufferSize: int, + lpBytesReturned: AlwaysNull, + lpOverlapped: CData, + /, + ) -> bool: + ... + + +class _Nt(Protocol): + """Statically typed version of the ntdll.dll functions we use.""" + + def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: + ... + + +class _Ws2(Protocol): + """Statically typed version of the ws2_32.dll functions we use.""" + + def WSAGetLastError(self) -> int: + ... + + def WSAIoctl( + self, + socket: CData, + dwIoControlCode: WSAIoctls, + lpvInBuffer: AlwaysNull, + cbInBuffer: int, + lpvOutBuffer: CData, + cbOutBuffer: int, + lpcbBytesReturned: CData, # int* + lpOverlapped: AlwaysNull, + # actually LPWSAOVERLAPPED_COMPLETION_ROUTINE + lpCompletionRoutine: AlwaysNull, + /, + ) -> int: + ... + + +class _DummyStruct(Protocol): + Offset: int + OffsetHigh: int + + +class _DummyUnion(Protocol): + DUMMYSTRUCTNAME: _DummyStruct + Pointer: object + + +class _Overlapped(Protocol): + Internal: int + InternalHigh: int + DUMMYUNIONNAME: _DummyUnion + hEvent: Handle + + +kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll")) +ntdll = cast(_Nt, ffi.dlopen("ntdll.dll")) +ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll")) ################################################################ # Magic numbers @@ -237,7 +421,7 @@ # https://www.magnumdb.com # (Tip: check the box to see "Hex value") -INVALID_HANDLE_VALUE = ffi.cast("HANDLE", -1) +INVALID_HANDLE_VALUE = Handle(ffi.cast("HANDLE", -1)) class ErrorCodes(enum.IntEnum): @@ -255,7 +439,7 @@ class ErrorCodes(enum.IntEnum): ERROR_NOT_SOCKET = 10038 -class FileFlags(enum.IntEnum): +class FileFlags(enum.IntFlag): GENERIC_READ = 0x80000000 SYNCHRONIZE = 0x00100000 FILE_FLAG_OVERLAPPED = 0x40000000 @@ -309,7 +493,7 @@ class IoControlCodes(enum.IntEnum): ################################################################ -def _handle(obj: int | CData) -> CData: +def _handle(obj: int | CData) -> Handle: # For now, represent handles as either cffi HANDLEs or as ints. If you # try to pass in a file descriptor instead, it's not going to work # out. (For that msvcrt.get_osfhandle does the trick, but I don't know if # matter, Python never allocates an fd.
So let's wait until we actually # encounter the problem before worrying about it. if isinstance(obj, int): - return ffi.cast("HANDLE", obj) - return obj + return Handle(ffi.cast("HANDLE", obj)) + return Handle(obj) + + +def handle_array(count: int) -> HandleArray: + """Make an array of handles.""" + return HandleArray(ffi.new(f"HANDLE[{count}]")) def raise_winerror( @@ -327,13 +516,17 @@ def raise_winerror( filename: str | None = None, filename2: str | None = None, ) -> NoReturn: + # assert sys.platform == "win32" # TODO: make this work in MyPy + # ... in the meanwhile, ffi.getwinerror() is undefined on non-Windows, necessitating the type + # ignores. + if winerror is None: - err = ffi.getwinerror() + err = ffi.getwinerror() # type: ignore[attr-defined,unused-ignore] if err is None: raise RuntimeError("No error set?") winerror, msg = err else: - err = ffi.getwinerror(winerror) + err = ffi.getwinerror(winerror) # type: ignore[attr-defined,unused-ignore] if err is None: raise RuntimeError("No error set?") _, msg = err diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 2a2ca6719d..7d941747e3 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -43,7 +43,7 @@ def sync_wait_reapable(pid: int) -> None: int waitid(int idtype, int id, siginfo_t* result, int options); """ ) - waitid_cffi = waitid_ffi.dlopen(None).waitid + waitid_cffi = waitid_ffi.dlopen(None).waitid # type: ignore[attr-defined] def sync_wait_reapable(pid: int) -> None: P_PID = 1 diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index 021aa699d2..bab6797c02 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -7,32 +7,16 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9857369255150554, + "completenessScore": 1, "diagnostics": [ - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.current_iocp" - }, { "message": "No docstring found for function \"trio.lowlevel.current_iocp\"", "name": "trio.lowlevel.current_iocp" }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.monitor_completion_key" - }, { "message": "No docstring found for function \"trio.lowlevel.monitor_completion_key\"", "name": "trio.lowlevel.monitor_completion_key" }, - { - "message": "Type annotation for parameter \"handle\" is missing", - "name": "trio.lowlevel.notify_closing" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.notify_closing" - }, { "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", "name": "trio.lowlevel.notify_closing" @@ -41,94 +25,26 @@ "message": "No docstring found for function \"trio.lowlevel.open_process\"", "name": "trio.lowlevel.open_process" }, - { - "message": "Type annotation for parameter \"handle\" is missing", - "name": "trio.lowlevel.readinto_overlapped" - }, - { - "message": "Type annotation for parameter \"buffer\" is missing", - "name": "trio.lowlevel.readinto_overlapped" - }, - { - "message": "Type annotation for parameter \"file_offset\" is missing", - "name": "trio.lowlevel.readinto_overlapped" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.readinto_overlapped" - }, { "message": "No docstring found for function \"trio.lowlevel.readinto_overlapped\"", "name": "trio.lowlevel.readinto_overlapped" }, - { - "message": "Type annotation for parameter \"handle\" is missing", - "name": 
"trio.lowlevel.register_with_iocp" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.register_with_iocp" - }, { "message": "No docstring found for function \"trio.lowlevel.register_with_iocp\"", "name": "trio.lowlevel.register_with_iocp" }, - { - "message": "Type annotation for parameter \"handle\" is missing", - "name": "trio.lowlevel.wait_overlapped" - }, - { - "message": "Type annotation for parameter \"lpOverlapped\" is missing", - "name": "trio.lowlevel.wait_overlapped" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.wait_overlapped" - }, { "message": "No docstring found for function \"trio.lowlevel.wait_overlapped\"", "name": "trio.lowlevel.wait_overlapped" }, - { - "message": "Type annotation for parameter \"sock\" is missing", - "name": "trio.lowlevel.wait_readable" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.wait_readable" - }, { "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", "name": "trio.lowlevel.wait_readable" }, - { - "message": "Type annotation for parameter \"sock\" is missing", - "name": "trio.lowlevel.wait_writable" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.wait_writable" - }, { "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", "name": "trio.lowlevel.wait_writable" }, - { - "message": "Type annotation for parameter \"handle\" is missing", - "name": "trio.lowlevel.write_overlapped" - }, - { - "message": "Type annotation for parameter \"data\" is missing", - "name": "trio.lowlevel.write_overlapped" - }, - { - "message": "Type annotation for parameter \"file_offset\" is missing", - "name": "trio.lowlevel.write_overlapped" - }, - { - "message": "Return type annotation is missing", - "name": "trio.lowlevel.write_overlapped" - }, { "message": "No docstring found for function \"trio.lowlevel.write_overlapped\"", "name": "trio.lowlevel.write_overlapped" @@ -144,8 +60,8 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 622, - "withUnknownType": 9 + "withKnownType": 631, + "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -180,7 +96,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 675, + "withKnownType": 680, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 47383c7055..b5612c59c3 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -267,7 +267,12 @@ def main() -> None: # pragma: no cover "runner.instruments", imports=IMPORTS_INSTRUMENT, ), - File(core / "_io_windows.py", "runner.io_manager", platform="win32"), + File( + core / "_io_windows.py", + "runner.io_manager", + platform="win32", + imports=IMPORTS_WINDOWS, + ), File( core / "_io_epoll.py", "runner.io_manager", @@ -317,6 +322,17 @@ def main() -> None: # pragma: no cover """ +IMPORTS_WINDOWS = """\ +from typing import TYPE_CHECKING, ContextManager + +if TYPE_CHECKING: + from .._file_io import _HasFileNo + from ._windows_cffi import Handle, CData + from typing_extensions import Buffer + + from ._unbounded_queue import UnboundedQueue +""" + if __name__ == "__main__": # pragma: no cover main() diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index 50a9d13ff2..d2193d9c86 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -9,6 +9,7 @@ ErrorCodes, _handle, ffi, + handle_array, kernel32, 
raise_winerror, ) @@ -57,7 +58,7 @@ async def WaitForSingleObject(obj: int | CData) -> None: def WaitForMultipleObjects_sync(*handles: int | CData) -> None: """Wait for any of the given Windows handles to be signaled.""" n = len(handles) - handle_arr = ffi.new(f"HANDLE[{n}]") + handle_arr = handle_array(n) for i in range(n): handle_arr[i] = handles[i] timeout = 0xFFFFFFFF # INFINITE