From 6b668afeea536ea3182baadec39da755728c79d0 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Fri, 3 Feb 2023 09:10:48 -0800 Subject: [PATCH 1/8] Implement our own Windows backend Implement polling for Windows Reimplement wepoll in Rust Fix immediately obvious bugs Fix AFD event translation Fix MSRV build errors Make sure to unlock packets after retrieval Lock sources list with rwlock instead of mutex --- Cargo.toml | 10 +- src/iocp/afd.rs | 607 ++++++++++++++++++++++++ src/iocp/mod.rs | 785 +++++++++++++++++++++++++++++++ src/iocp/port.rs | 317 +++++++++++++ src/lib.rs | 4 +- src/wepoll.rs | 254 ---------- tests/concurrent_modification.rs | 2 +- tests/io.rs | 38 ++ tests/precision.rs | 4 +- 9 files changed, 1760 insertions(+), 261 deletions(-) create mode 100644 src/iocp/afd.rs create mode 100644 src/iocp/mod.rs create mode 100644 src/iocp/port.rs delete mode 100644 src/wepoll.rs create mode 100644 tests/io.rs diff --git a/Cargo.toml b/Cargo.toml index ac0b0ac..e6e57ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,13 +32,19 @@ autocfg = "1" libc = "0.2.77" [target.'cfg(windows)'.dependencies] -wepoll-ffi = { version = "0.1.2", features = ["null-overlapped-wakeups-patch"] } +bitflags = "1.3.2" +concurrent-queue = "2.1.0" +pin-project-lite = "0.2.9" [target.'cfg(windows)'.dependencies.windows-sys] version = "0.45" features = [ + "Win32_Networking_WinSock", "Win32_System_IO", - "Win32_Foundation" + "Win32_System_LibraryLoader", + "Win32_System_WindowsProgramming", + "Win32_Storage_FileSystem", + "Win32_Foundation", ] [dev-dependencies] diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs new file mode 100644 index 0000000..a8a2de3 --- /dev/null +++ b/src/iocp/afd.rs @@ -0,0 +1,607 @@ +//! Safe wrapper around \Device\Afd + +use super::port::{Completion, CompletionHandle}; + +use std::cell::UnsafeCell; +use std::fmt; +use std::io; +use std::marker::{PhantomData, PhantomPinned}; +use std::mem::{size_of, transmute, MaybeUninit}; +use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket}; +use std::pin::Pin; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Once; + +use windows_sys::Win32::Foundation::{ + CloseHandle, HANDLE, HINSTANCE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, + UNICODE_STRING, +}; +use windows_sys::Win32::Networking::WinSock::{ + WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR, +}; +use windows_sys::Win32::Storage::FileSystem::{ + FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE, +}; +use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress}; +use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATTRIBUTES}; + +#[derive(Default)] +#[repr(C)] +pub(super) struct AfdPollInfo { + /// The timeout for this poll. + timeout: i64, + + /// The number of handles being polled. + handle_count: u32, + + /// Whether or not this poll is exclusive for this handle. + exclusive: u32, + + /// The handles to poll. + handles: [AfdPollHandleInfo; 1], +} + +#[derive(Default)] +#[repr(C)] +struct AfdPollHandleInfo { + /// The handle to poll. + handle: HANDLE, + + /// The events to poll for. + events: AfdPollMask, + + /// The status of the poll. + status: NTSTATUS, +} + +impl AfdPollInfo { + pub(super) fn handle_count(&self) -> u32 { + self.handle_count + } + + pub(super) fn events(&self) -> AfdPollMask { + self.handles[0].events + } +} + +bitflags::bitflags! 
{ + #[derive(Default)] + pub(super) struct AfdPollMask: u32 { + const RECEIVE = 0x001; + const RECEIVE_EXPEDITED = 0x002; + const SEND = 0x004; + const DISCONNECT = 0x008; + const ABORT = 0x010; + const LOCAL_CLOSE = 0x020; + const ACCEPT = 0x080; + const CONNECT_FAIL = 0x100; + } +} + +pub(super) trait HasAfdInfo { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell>; +} + +macro_rules! define_ntdll_import { + ( + $( + $(#[$attr:meta])* + fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty; + )* + ) => { + /// Imported functions from ntdll.dll. + #[allow(non_snake_case)] + pub(super) struct NtdllImports { + $( + $(#[$attr])* + $name: unsafe extern "system" fn($($arg_ty),*) -> $ret, + )* + } + + #[allow(non_snake_case)] + impl NtdllImports { + unsafe fn load(ntdll: HINSTANCE) -> io::Result { + $( + let $name = { + const NAME: &str = concat!(stringify!($name), "\0"); + let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _); + + let addr = match addr { + Some(addr) => addr, + None => { + log::error!("Failed to load ntdll function {}", NAME); + return Err(io::Error::last_os_error()); + }, + }; + + transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr) + }; + )* + + Ok(Self { + $( + $name, + )* + }) + } + + $( + $(#[$attr])* + unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret { + (self.$name)($($arg),*) + } + )* + } + }; +} + +define_ntdll_import! { + /// Cancels an ongoing I/O operation. + fn NtCancelIoFileEx( + FileHandle: HANDLE, + IoRequestToCancel: *mut IO_STATUS_BLOCK, + IoStatusBlock: *mut IO_STATUS_BLOCK + ) -> NTSTATUS; + + /// Opens or creates a file handle. + #[allow(clippy::too_many_arguments)] + fn NtCreateFile( + FileHandle: *mut HANDLE, + DesiredAccess: u32, + ObjectAttributes: *mut OBJECT_ATTRIBUTES, + IoStatusBlock: *mut IO_STATUS_BLOCK, + AllocationSize: *mut i64, + FileAttributes: u32, + ShareAccess: u32, + CreateDisposition: u32, + CreateOptions: u32, + EaBuffer: *mut (), + EaLength: u32 + ) -> NTSTATUS; + + /// Runs an I/O control on a file handle. + /// + /// Practically equivalent to `ioctl`. + #[allow(clippy::too_many_arguments)] + fn NtDeviceIoControlFile( + FileHandle: HANDLE, + Event: HANDLE, + ApcRoutine: *mut (), + ApcContext: *mut (), + IoStatusBlock: *mut IO_STATUS_BLOCK, + IoControlCode: u32, + InputBuffer: *mut (), + InputBufferLength: u32, + OutputBuffer: *mut (), + OutputBufferLength: u32 + ) -> NTSTATUS; + + /// Converts `NTSTATUS` to a DOS error code. + fn RtlNtStatusToDosError( + Status: NTSTATUS + ) -> u32; +} + +impl NtdllImports { + fn get() -> io::Result<&'static Self> { + macro_rules! s { + ($e:expr) => {{ + $e as u16 + }}; + } + + // ntdll.dll + static NTDLL_NAME: &[u16] = &[ + s!('n'), + s!('t'), + s!('d'), + s!('l'), + s!('l'), + s!('.'), + s!('d'), + s!('l'), + s!('l'), + s!('\0'), + ]; + static NTDLL_IMPORTS: OnceCell> = OnceCell::new(); + + NTDLL_IMPORTS + .get_or_init(|| unsafe { + let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _); + + if ntdll == 0 { + log::error!("Failed to load ntdll.dll"); + return Err(io::Error::last_os_error()); + } + + NtdllImports::load(ntdll) + }) + .as_ref() + .map_err(|e| io::Error::from(e.kind())) + } + + pub(super) fn force_load() -> io::Result<()> { + Self::get()?; + Ok(()) + } +} + +/// The handle to the AFD device. +pub(super) struct Afd { + /// The handle to the AFD device. + handle: HANDLE, + + /// We own `T`. 
+ _marker: PhantomData, +} + +impl fmt::Debug for Afd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct WriteAsHex(HANDLE); + + impl fmt::Debug for WriteAsHex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:010x}", self.0) + } + } + + f.debug_struct("Afd") + .field("handle", &WriteAsHex(self.handle)) + .finish() + } +} + +impl Drop for Afd { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +impl AsRawHandle for Afd { + fn as_raw_handle(&self) -> RawHandle { + self.handle as _ + } +} + +impl Afd +where + T::Completion: AsIoStatusBlock + HasAfdInfo, +{ + /// Create a new AFD device. + pub(super) fn new() -> io::Result { + macro_rules! s { + ($e:expr) => { + ($e) as u16 + }; + } + + /// \Device\Afd\Smol + const AFD_NAME: &[u16] = &[ + s!('\\'), + s!('D'), + s!('e'), + s!('v'), + s!('i'), + s!('c'), + s!('e'), + s!('\\'), + s!('A'), + s!('f'), + s!('d'), + s!('\\'), + s!('S'), + s!('m'), + s!('o'), + s!('l'), + s!('\0'), + ]; + + // Set up device attributes. + let mut device_name = UNICODE_STRING { + Length: (AFD_NAME.len() * size_of::()) as u16, + MaximumLength: (AFD_NAME.len() * size_of::()) as u16, + Buffer: AFD_NAME.as_ptr() as *mut _, + }; + let mut device_attributes = OBJECT_ATTRIBUTES { + Length: size_of::() as u32, + RootDirectory: 0, + ObjectName: &mut device_name, + Attributes: 0, + SecurityDescriptor: ptr::null_mut(), + SecurityQualityOfService: ptr::null_mut(), + }; + + let mut handle = MaybeUninit::::uninit(); + let mut iosb = MaybeUninit::::zeroed(); + let ntdll = NtdllImports::get()?; + + let result = unsafe { + ntdll.NtCreateFile( + handle.as_mut_ptr(), + SYNCHRONIZE, + &mut device_attributes, + iosb.as_mut_ptr(), + ptr::null_mut(), + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN, + 0, + ptr::null_mut(), + 0, + ) + }; + + if result != STATUS_SUCCESS { + let real_code = unsafe { ntdll.RtlNtStatusToDosError(result) }; + + return Err(io::Error::from_raw_os_error(real_code as i32)); + } + + let handle = unsafe { handle.assume_init() }; + + Ok(Self { + handle, + _marker: PhantomData, + }) + } + + /// Begin polling with the provided handle. + pub(super) fn poll( + &self, + packet: T, + base_socket: RawSocket, + afd_events: AfdPollMask, + ) -> io::Result<()> { + const IOCTL_AFD_POLL: u32 = 0x00012024; + + // Lock the packet. + if !packet.get().try_lock() { + return Err(io::Error::new( + io::ErrorKind::WouldBlock, + "packet is already in use", + )); + } + + // Set up the AFD poll info. + let poll_info = unsafe { + let poll_info = Pin::into_inner_unchecked(packet.get().afd_info()).get(); + + // Initialize the AFD poll info. 
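+            // A single handle is polled with an effectively infinite timeout; completion
+            // is signalled through the I/O completion port instead.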
+ (*poll_info).exclusive = false.into(); + (*poll_info).handle_count = 1; + (*poll_info).timeout = std::i64::MAX; + (*poll_info).handles[0].handle = base_socket as HANDLE; + (*poll_info).handles[0].status = 0; + (*poll_info).handles[0].events = afd_events; + + poll_info + }; + + let iosb = T::into_ptr(packet).cast::(); + // Set Status to pending + unsafe { + (*iosb).Anonymous.Status = STATUS_PENDING; + } + + let ntdll = NtdllImports::get()?; + let result = unsafe { + ntdll.NtDeviceIoControlFile( + self.handle, + 0, + ptr::null_mut(), + iosb.cast(), + iosb.cast(), + IOCTL_AFD_POLL, + poll_info.cast(), + size_of::() as u32, + poll_info.cast(), + size_of::() as u32, + ) + }; + + match result { + STATUS_SUCCESS => Ok(()), + STATUS_PENDING => Err(io::ErrorKind::WouldBlock.into()), + status => { + let real_code = unsafe { ntdll.RtlNtStatusToDosError(status) }; + + Err(io::Error::from_raw_os_error(real_code as i32)) + } + } + } + + /// Cancel an ongoing poll operation. + /// + /// # Safety + /// + /// The poll operation must currently be in progress for this AFD. + pub(super) unsafe fn cancel(&self, packet: &T) -> io::Result<()> { + let ntdll = NtdllImports::get()?; + + let result = { + // First, check if the packet is still in use. + let iosb = packet.as_ptr().cast::(); + + if (*iosb).Anonymous.Status != STATUS_PENDING { + return Ok(()); + } + + // Cancel the packet. + let mut cancel_iosb = MaybeUninit::::zeroed(); + + ntdll.NtCancelIoFileEx(self.handle, iosb, cancel_iosb.as_mut_ptr()) + }; + + if result == STATUS_SUCCESS || result == STATUS_NOT_FOUND { + Ok(()) + } else { + let real_code = ntdll.RtlNtStatusToDosError(result); + + Err(io::Error::from_raw_os_error(real_code as i32)) + } + } +} + +/// A one-time initialization cell. +struct OnceCell { + /// The value. + value: UnsafeCell>, + + /// The one-time initialization. + once: Once, +} + +unsafe impl Send for OnceCell {} +unsafe impl Sync for OnceCell {} + +impl OnceCell { + /// Creates a new `OnceCell`. + pub const fn new() -> Self { + OnceCell { + value: UnsafeCell::new(MaybeUninit::uninit()), + once: Once::new(), + } + } + + /// Gets the value or initializes it. + pub fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> T, + { + self.once.call_once(|| unsafe { + let value = f(); + *self.value.get() = MaybeUninit::new(value); + }); + + unsafe { &*self.value.get().cast() } + } +} + +pin_project_lite::pin_project! { + /// An I/O status block paired with some auxillary data. + #[repr(C)] + pub(super) struct IoStatusBlock { + // The I/O status block. + iosb: UnsafeCell, + + // Whether or not the block is in use. + in_use: AtomicBool, + + // The auxillary data. + #[pin] + data: T, + + // This block is not allowed to move. 
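+        // The kernel writes into `iosb` while an operation is pending, so the block's
+        // address must stay stable until the operation completes.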
+ #[pin] + _marker: PhantomPinned, + } +} + +impl fmt::Debug for IoStatusBlock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IoStatusBlock") + .field("iosb", &"..") + .field("in_use", &self.in_use) + .field("data", &self.data) + .finish() + } +} + +impl From for IoStatusBlock { + fn from(data: T) -> Self { + Self { + iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }), + in_use: AtomicBool::new(false), + data, + _marker: PhantomPinned, + } + } +} + +impl IoStatusBlock { + pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell { + self.project_ref().iosb + } + + pub(super) fn data(self: Pin<&Self>) -> Pin<&T> { + self.project_ref().data + } +} + +impl HasAfdInfo for IoStatusBlock { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { + self.project_ref().data.afd_info() + } +} + +/// Can be transmuted to an I/O status block. +/// +/// # Safety +/// +/// A pointer to `T` must be able to be converted to a pointer to `IO_STATUS_BLOCK` +/// without any issues. +pub(super) unsafe trait AsIoStatusBlock {} + +unsafe impl AsIoStatusBlock for IoStatusBlock {} +unsafe impl Completion for IoStatusBlock { + fn try_lock(self: Pin<&Self>) -> bool { + !self.in_use.swap(true, Ordering::SeqCst) + } + + unsafe fn unlock(self: Pin<&Self>) { + self.in_use.store(false, Ordering::SeqCst); + } +} + +/// Get the base socket associated with a socket. +pub(super) fn base_socket(sock: RawSocket) -> io::Result { + // First, try the SIO_BASE_HANDLE ioctl. + let result = unsafe { try_socket_ioctl(sock, SIO_BASE_HANDLE) }; + + match result { + Ok(sock) => return Ok(sock), + Err(e) if e.kind() == io::ErrorKind::InvalidInput => return Err(e), + Err(_) => {} + } + + // Some poorly coded LSPs may not handle SIO_BASE_HANDLE properly, but in some cases may + // handle SIO_BSP_HANDLE_POLL better. Try that. + let result = unsafe { try_socket_ioctl(sock, SIO_BSP_HANDLE_POLL)? }; + if result == sock { + return Err(io::Error::from(io::ErrorKind::InvalidInput)); + } + + // Try `SIO_BASE_HANDLE` again, in case the LSP fixed itself. + unsafe { try_socket_ioctl(result, SIO_BASE_HANDLE) } +} + +/// Run an IOCTL on a socket and return a socket. +/// +/// # Safety +/// +/// The socket must be valid. +unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result { + let mut out = MaybeUninit::::uninit(); + let mut bytes = MaybeUninit::::uninit(); + + let result = WSAIoctl( + sock as _, + ioctl, + ptr::null_mut(), + 0, + out.as_mut_ptr().cast(), + size_of::() as u32, + bytes.as_mut_ptr(), + ptr::null_mut(), + None, + ); + + if result == SOCKET_ERROR { + return Err(io::Error::last_os_error()); + } + + Ok(out.assume_init()) +} diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs new file mode 100644 index 0000000..9762de5 --- /dev/null +++ b/src/iocp/mod.rs @@ -0,0 +1,785 @@ +//! Bindings to Windows I/O Completion Ports. 
+ +mod afd; +mod port; + +use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; +use port::{IoCompletionPort, OverlappedEntry}; +use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED}; + +use crate::{Event, PollMode}; + +use concurrent_queue::ConcurrentQueue; +use pin_project_lite::pin_project; + +use std::cell::UnsafeCell; +use std::collections::hash_map::{Entry, HashMap}; +use std::fmt; +use std::io; +use std::marker::PhantomPinned; +use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket}; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::time::{Duration, Instant}; + +#[cfg(not(polling_no_io_safety))] +use std::os::windows::io::{AsHandle, BorrowedHandle}; + +/// Interface to I/O completion ports. +#[derive(Debug)] +pub(super) struct Poller { + /// The I/O completion port. + port: IoCompletionPort, + + /// List of currently active AFD instances. + afd: Mutex>>>, + + /// The state of the sources registered with this poller. + sources: RwLock>, + + /// Sockets with pending updates. + pending_updates: ConcurrentQueue, + + /// Are we currently polling? + polling: AtomicBool, + + /// A list of completion packets. + packets: Mutex>>, + + /// The packet used to notify the poller. + notifier: Packet, +} + +unsafe impl Send for Poller {} +unsafe impl Sync for Poller {} + +impl Poller { + /// Creates a new poller. + pub(super) fn new() -> io::Result { + // Make sure AFD is able to be used. + if let Err(e) = afd::NtdllImports::force_load() { + return Err(crate::unsupported_error(format!( + "Failed to initialize I/O completion ports: {}\nThis usually only happens for old Windows or Wine.", + e + ))); + } + + let port = IoCompletionPort::new(0)?; + + log::trace!("new: handle={:?}", &port); + + Ok(Poller { + port, + afd: Mutex::new(Vec::new()), + sources: RwLock::new(HashMap::new()), + pending_updates: ConcurrentQueue::bounded(1024), + polling: AtomicBool::new(false), + packets: Mutex::new(Vec::with_capacity(1024)), + notifier: Arc::pin( + PacketInner::Wakeup { + _pinned: PhantomPinned, + } + .into(), + ), + }) + } + + /// Whether this poller supports level-triggered events. + pub(super) fn supports_level(&self) -> bool { + true + } + + /// Whether this poller supports edge-triggered events. + pub(super) fn supports_edge(&self) -> bool { + false + } + + /// Add a new source to the poller. + pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> { + log::trace!( + "add: handle={:?}, sock={}, ev={:?}", + self.port, + socket, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Create a new packet. + let socket_state = { + let state = SocketState { + socket, + base_socket: base_socket(socket)?, + interest, + interest_error: true, + afd: self.afd_handle()?, + mode, + waiting_on_delete: false, + status: SocketStatus::Idle, + }; + + Arc::pin(IoStatusBlock::from(PacketInner::Socket { + packet: UnsafeCell::new(AfdPollInfo::default()), + socket: Mutex::new(state), + })) + }; + + // Keep track of the source in the poller. 
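+        // A socket may only be registered once; a second `add` for the same socket
+        // fails below with `AlreadyExists`.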
+ { + let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner()); + + match sources.entry(socket) { + Entry::Vacant(v) => { + v.insert(Pin::>::clone(&socket_state)); + } + + Entry::Occupied(_) => { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + } + } + + // Update the packet. + self.update_packet(socket_state) + } + + /// Update a source in the poller. + pub(super) fn modify( + &self, + socket: RawSocket, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + log::trace!( + "modify: handle={:?}, sock={}, ev={:?}", + self.port, + socket, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Get a reference to the source. + let source = { + let sources = self.sources.read().unwrap_or_else(|e| e.into_inner()); + + match sources.get(&socket) { + Some(s) => s.clone(), + None => { + return Err(io::Error::from(io::ErrorKind::NotFound)); + } + } + }; + + // Set the new event. + if source.as_ref().set_events(interest, mode) { + self.update_packet(source)?; + } + + Ok(()) + } + + /// Delete a source from the poller. + pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> { + log::trace!("remove: handle={:?}, sock={}", self.port, socket); + + // Get a reference to the source. + let source = { + let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner()); + + match sources.remove(&socket) { + Some(s) => s, + None => { + // Just return. + return Ok(()); + } + } + }; + + // Indicate to the source that it is being deleted. + // This cancels any ongoing AFD_IOCTL_POLL operations. + source.begin_delete() + } + + /// Wait for events. + pub(super) fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { + log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout); + + let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout)); + let mut packets = self.packets.lock().unwrap_or_else(|e| e.into_inner()); + let mut notified = false; + events.packets.clear(); + + loop { + let mut new_events = 0; + + // Indicate that we are now polling. + debug_assert!(!self.polling.swap(true, Ordering::SeqCst)); + + let guard = CallOnDrop(|| { + debug_assert!(self.polling.swap(false, Ordering::SeqCst)); + }); + + // Process every entry in the queue before we start polling. + self.drain_update_queue(false)?; + + // Get the time to wait for. + let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now())); + + // Wait for I/O events. + let len = self.port.wait(&mut packets, timeout)?; + log::trace!("new events: handle={:?}, len={}", self.port, len); + + // We are no longer polling. + drop(guard); + + // Process all of the events. + for entry in packets.drain(..) { + let packet = entry.into_packet(); + + // Feed the event into the packet. + match packet.feed_event(self)? { + FeedEventResult::NoEvent => {} + FeedEventResult::Event(event) => { + events.packets.push(event); + new_events += 1; + } + FeedEventResult::Notified => { + notified = true; + } + } + } + + // Break if there was a notification or at least one event, or if deadline is reached. + let timeout_is_empty = + timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0); + if notified || new_events > 0 || timeout_is_empty { + break; + } + + log::trace!("wait: no events found, re-entering polling loop"); + } + + Ok(()) + } + + /// Notify this poller. 
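+    ///
+    /// This posts the dedicated wakeup packet to the completion port; `wait()` recognizes
+    /// it and returns without reporting it as an I/O event.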
+ pub(super) fn notify(&self) -> io::Result<()> { + // Push the notify packet into the IOCP. + self.port.post(0, 0, self.notifier.clone()) + } + + /// Run an update on a packet. + fn update_packet(&self, mut packet: Packet) -> io::Result<()> { + loop { + // If we are currently polling, we need to update the packet immediately. + if self.polling.load(Ordering::Acquire) { + packet.update()?; + return Ok(()); + } + + // Try to queue the update. + match self.pending_updates.push(packet) { + Ok(()) => return Ok(()), + Err(p) => packet = p.into_inner(), + } + + // If we failed to queue the update, we need to drain the queue first. + self.drain_update_queue(true)?; + } + } + + /// Drain the update queue. + fn drain_update_queue(&self, limit: bool) -> io::Result<()> { + let max = if limit { + self.pending_updates.capacity().unwrap() + } else { + std::usize::MAX + }; + + // Only drain the queue's capacity, since this could in theory run forever. + for _ in 0..max { + if let Ok(packet) = self.pending_updates.pop() { + packet.update()?; + } else { + return Ok(()); + } + } + + Ok(()) + } + + /// Get a handle to the AFD reference. + fn afd_handle(&self) -> io::Result>> { + const AFD_MAX_SIZE: usize = 32; + + // See if there are any existing AFD instances that we can use. + let mut afd_handles = self.afd.lock().unwrap_or_else(|e| e.into_inner()); + if let Some(handle) = afd_handles.iter().find(|h| { + let ref_count = Arc::strong_count(h).saturating_sub(1); + ref_count < AFD_MAX_SIZE + }) { + return Ok(handle.clone()); + } + + // Create a new AFD instance. + let afd = Arc::new(Afd::new()?); + + // Register the AFD instance with the I/O completion port. + self.port.register(&*afd, true)?; + + // Insert a copy of the AFD instance into the list. + afd_handles.push(afd.clone()); + + Ok(afd) + } +} + +impl AsRawHandle for Poller { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +#[cfg(not(polling_no_io_safety))] +impl AsHandle for Poller { + fn as_handle(&self) -> BorrowedHandle<'_> { + unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } + } +} + +/// The container for events. +pub(super) struct Events { + /// List of IOCP packets. + packets: Vec, +} + +unsafe impl Send for Events {} + +impl Events { + /// Creates an empty list of events. + pub(super) fn new() -> Events { + Events { + packets: Vec::with_capacity(1024), + } + } + + /// Iterate over I/O events. + pub(super) fn iter(&self) -> impl Iterator + '_ { + self.packets.iter().cloned() + } +} + +/// The type of our completion packet. +type Packet = Pin>; +type PacketUnwrapped = IoStatusBlock; + +pin_project! { + /// The inner type of the packet. + #[project_ref = PacketInnerProj] + #[project = PacketInnerProjMut] + enum PacketInner { + // A packet for a socket. + Socket { + // The AFD packet state. + #[pin] + packet: UnsafeCell, + + // The socket state. + socket: Mutex + }, + + // A packet used to wake up the poller. + Wakeup { #[pin] _pinned: PhantomPinned }, + } +} + +impl fmt::Debug for PacketInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Wakeup { .. } => f.write_str("Wakeup { .. }"), + Self::Socket { socket, .. } => f + .debug_struct("Socket") + .field("packet", &"..") + .field("socket", socket) + .finish(), + } + } +} + +impl HasAfdInfo for PacketInner { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { + match self.project_ref() { + PacketInnerProj::Socket { packet, .. } => packet, + PacketInnerProj::Wakeup { .. 
} => unreachable!(), + } + } +} + +impl PacketUnwrapped { + /// Set the new events that this socket is waiting on. + /// + /// Returns `true` if we need to be updated. + fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool { + let mut socket = match self.socket_state() { + Some(s) => s, + None => return false, + }; + + socket.interest = interest; + socket.mode = mode; + socket.interest_error = true; + + match socket.status { + SocketStatus::Polling { readable, writable } => { + (interest.readable && !readable) || (interest.writable && !writable) + } + _ => true, + } + } + + /// Update the socket and install the new status in AFD. + fn update(self: Pin>) -> io::Result<()> { + let mut socket = match self.as_ref().socket_state() { + Some(s) => s, + None => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")), + }; + + // If we are waiting on a delete, just return, dropping the packet. + if socket.waiting_on_delete { + return Ok(()); + } + + // Check the current status. + match socket.status { + SocketStatus::Polling { readable, writable } => { + // If we need to poll for events aside from what we are currently polling, we need + // to update the packet. Cancel the ongoing poll. + if (socket.interest.readable && !readable) + || (socket.interest.writable && !writable) + { + return self.cancel(socket); + } + + // All events that we are currently waiting on are accounted for. + Ok(()) + } + + SocketStatus::Cancelled => { + // The ongoing operation was cancelled, and we're still waiting for it to return. + // For now, wait until the top-level loop calls feed_event(). + Ok(()) + } + + SocketStatus::Idle => { + // Start a new poll. + let result = socket.afd.poll( + self.clone(), + socket.base_socket, + event_to_afd_mask( + socket.interest.readable, + socket.interest.writable, + socket.interest_error, + ), + ); + + match result { + Ok(()) => {} + + Err(err) + if err.raw_os_error() == Some(ERROR_IO_PENDING as i32) + || err.kind() == io::ErrorKind::WouldBlock => + { + // The operation is pending. + } + + Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => { + // The socket was closed. We need to delete it. + // This should happen after we drop it here. + } + + Err(err) => return Err(err), + } + + // We are now polling for the current events. + socket.status = SocketStatus::Polling { + readable: socket.interest.readable, + writable: socket.interest.writable, + }; + + Ok(()) + } + } + } + + /// This socket state was notified; see if we need to update it. + fn feed_event(self: Pin>, poller: &Poller) -> io::Result { + let inner = self.as_ref().data().project_ref(); + + let (afd_info, socket) = match inner { + PacketInnerProj::Socket { packet, socket } => (packet, socket), + PacketInnerProj::Wakeup { .. } => { + // The poller was notified. + return Ok(FeedEventResult::Notified); + } + }; + + let mut socket_state = socket.lock().unwrap_or_else(|e| e.into_inner()); + let mut event = Event::none(socket_state.interest.key); + + // Put ourselves into the idle state. + socket_state.status = SocketStatus::Idle; + + // If we are waiting to be deleted, just return and let the drop handler do their thing. + if socket_state.waiting_on_delete { + return Ok(FeedEventResult::NoEvent); + } + + unsafe { + // SAFETY: The packet is not in transit. + let iosb = &mut *self.as_ref().iosb().get(); + + // Check the status. + match iosb.Anonymous.Status { + STATUS_CANCELLED => { + // Poll request was cancelled. 
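+                    // Nothing is reported to the caller; the packet is handed back to
+                    // `update_packet` below so polling can be re-armed if needed.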
+ } + + status if status < 0 => { + // There was an error, so we signal both ends. + event.readable = true; + event.writable = true; + } + + _ => { + // Check in on the AFD data. + let afd_data = &*afd_info.get(); + + if afd_data.handle_count() >= 1 { + let events = afd_data.events(); + + // If we closed the socket, remove it from being polled. + if events.contains(AfdPollMask::LOCAL_CLOSE) { + let source = { + let mut sources = + poller.sources.write().unwrap_or_else(|e| e.into_inner()); + sources.remove(&socket_state.socket).unwrap() + }; + return source.begin_delete().map(|()| FeedEventResult::NoEvent); + } + + // Report socket-related events. + let (readable, writable) = afd_mask_to_event(events); + event.readable = readable; + event.writable = writable; + } + } + } + } + + // Filter out events that the user didn't ask for. + { + if !socket_state.interest.readable { + event.readable = false; + } + + if !socket_state.interest.writable { + event.writable = false; + } + } + + // If this event doesn't have anything that interests us, don't return or + // update the oneshot state. + let return_value = if event.readable || event.writable { + // If we are in oneshot mode, remove the interest. + if matches!(socket_state.mode, PollMode::Oneshot) { + socket_state.interest = Event::none(socket_state.interest.key); + socket_state.interest_error = false; + } + + FeedEventResult::Event(event) + } else { + FeedEventResult::NoEvent + }; + + // Put ourselves in the update queue. + drop(socket_state); + poller.update_packet(self)?; + + // Return the event. + Ok(return_value) + } + + /// Begin deleting this socket. + fn begin_delete(self: Pin>) -> io::Result<()> { + // If we aren't already being deleted, start deleting. + let mut socket = self + .as_ref() + .socket_state() + .expect("can't delete notification packet"); + if !socket.waiting_on_delete { + socket.waiting_on_delete = true; + + if matches!(socket.status, SocketStatus::Polling { .. }) { + // Cancel the ongoing poll. + self.cancel(socket)?; + } + } + + // Either drop it now or wait for it to be dropped later. + Ok(()) + } + + fn cancel(self: &Pin>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> { + assert!(matches!(socket.status, SocketStatus::Polling { .. })); + + // Send the cancel request. + unsafe { + socket.afd.cancel(self)?; + } + + // Move state to cancelled. + socket.status = SocketStatus::Cancelled; + + Ok(()) + } + + fn socket_state(self: Pin<&Self>) -> Option> { + let inner = self.data().project_ref(); + + let state = match inner { + PacketInnerProj::Wakeup { .. } => return None, + PacketInnerProj::Socket { socket, .. } => socket, + }; + + let guard = state.lock().unwrap_or_else(|e| e.into_inner()); + Some(guard) + } +} + +/// Per-socket state. +#[derive(Debug)] +struct SocketState { + /// The raw socket handle. + socket: RawSocket, + + /// The base socket handle. + base_socket: RawSocket, + + /// The event that this socket is currently waiting on. + interest: Event, + + /// Whether to listen for error events. + interest_error: bool, + + /// The current poll mode. + mode: PollMode, + + /// The AFD instance that this socket is registered with. + afd: Arc>, + + /// Whether this socket is waiting to be deleted. + waiting_on_delete: bool, + + /// The current status of the socket. + status: SocketStatus, +} + +/// The mode that a socket can be in. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum SocketStatus { + /// We are currently not polling. + Idle, + + /// We are currently polling these events. 
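+    ///
+    /// The flags record exactly which readiness AFD was asked for, so `set_events` and
+    /// `update` can tell when the poll has to be cancelled and re-armed with a wider mask.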
+ Polling { + /// We are currently polling for readable events. + readable: bool, + + /// We are currently polling for writable events. + writable: bool, + }, + + /// The last poll operation was cancelled, and we're waiting for it to + /// complete. + Cancelled, +} + +/// The result of calling `feed_event`. +#[derive(Debug)] +enum FeedEventResult { + /// No event was yielded. + NoEvent, + + /// An event was yielded. + Event(Event), + + /// The poller has been notified. + Notified, +} + +fn event_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask { + use afd::AfdPollMask as AfdPoll; + + let mut mask = AfdPoll::empty(); + + if error || readable || writable { + mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL; + } + + if readable { + mask |= + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED; + } + + if writable { + mask |= AfdPoll::SEND; + } + + mask +} + +fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) { + use afd::AfdPollMask as AfdPoll; + + let mut readable = false; + let mut writable = false; + + if mask.intersects( + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED, + ) { + readable = true; + } + + if mask.intersects(AfdPoll::SEND) { + writable = true; + } + + if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) { + readable = true; + writable = true; + } + + (readable, writable) +} + +struct CallOnDrop(F); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } +} diff --git a/src/iocp/port.rs b/src/iocp/port.rs new file mode 100644 index 0000000..ccc9575 --- /dev/null +++ b/src/iocp/port.rs @@ -0,0 +1,317 @@ +//! A safe wrapper around the Windows I/O API. + +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::io; +use std::marker::PhantomData; +use std::mem::MaybeUninit; +use std::ops::Deref; +use std::os::windows::io::{AsRawHandle, RawHandle}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE}; +use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes; +use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE}; +use windows_sys::Win32::System::IO::{ + CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED, + OVERLAPPED_ENTRY, +}; + +/// A completion block which can be used with I/O completion ports. +/// +/// # Safety +/// +/// This must be a valid completion block. +pub(super) unsafe trait Completion { + /// Signal to the completion block that we are about to start an operation. + fn try_lock(self: Pin<&Self>) -> bool; + + /// Unlock the completion block. + unsafe fn unlock(self: Pin<&Self>); +} + +/// The pointer to a completion block. +/// +/// # Safety +/// +/// This must be a valid completion block. +pub(super) unsafe trait CompletionHandle: Deref + Sized { + /// Type of the completion block. + type Completion: Completion; + + /// Get a pointer to the completion block. + fn get(&self) -> Pin<&Self::Completion>; + + /// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`. + fn into_ptr(this: Self) -> *mut OVERLAPPED; + + /// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block. + /// + /// # Safety + /// + /// This must be a valid pointer to a completion block. + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self; + + /// Convert to a pointer without losing ownership. 
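+    ///
+    /// Unlike `into_ptr`, this does not consume the handle; the returned pointer is only
+    /// valid for as long as `self` is kept alive.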
+ fn as_ptr(&self) -> *mut OVERLAPPED; +} + +unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> { + type Completion = T; + + fn get(&self) -> Pin<&Self::Completion> { + *self + } + + fn into_ptr(this: Self) -> *mut OVERLAPPED { + unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED } + } + + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { + Pin::new_unchecked(&*(ptr as *const T)) + } + + fn as_ptr(&self) -> *mut OVERLAPPED { + self.get_ref() as *const T as *mut OVERLAPPED + } +} + +unsafe impl CompletionHandle for Pin> { + type Completion = T; + + fn get(&self) -> Pin<&Self::Completion> { + self.as_ref() + } + + fn into_ptr(this: Self) -> *mut OVERLAPPED { + unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED } + } + + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { + Pin::new_unchecked(Arc::from_raw(ptr as *const T)) + } + + fn as_ptr(&self) -> *mut OVERLAPPED { + self.as_ref().get_ref() as *const T as *mut OVERLAPPED + } +} + +/// A handle to the I/O completion port. +pub(super) struct IoCompletionPort { + /// The underlying handle. + handle: HANDLE, + + /// We own the status block. + _marker: PhantomData, +} + +impl Drop for IoCompletionPort { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +impl AsRawHandle for IoCompletionPort { + fn as_raw_handle(&self) -> RawHandle { + self.handle as _ + } +} + +impl fmt::Debug for IoCompletionPort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct WriteAsHex(HANDLE); + + impl fmt::Debug for WriteAsHex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:010x}", self.0) + } + } + + f.debug_struct("IoCompletionPort") + .field("handle", &WriteAsHex(self.handle)) + .finish() + } +} + +impl IoCompletionPort { + /// Create a new I/O completion port. + pub(super) fn new(threads: usize) -> io::Result { + let handle = unsafe { + CreateIoCompletionPort( + INVALID_HANDLE_VALUE, + 0, + 0, + threads.try_into().expect("too many threads"), + ) + }; + + if handle == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { + handle, + _marker: PhantomData, + }) + } + } + + /// Register a handle with this I/O completion port. + pub(super) fn register( + &self, + handle: &impl AsRawHandle, // TODO change to AsHandle + skip_set_event_on_handle: bool, + ) -> io::Result<()> { + let handle = handle.as_raw_handle(); + + let result = + unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) }; + + if result == 0 { + return Err(io::Error::last_os_error()); + } + + if skip_set_event_on_handle { + // Set the skip event on handle. + let result = unsafe { + SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _) + }; + + if result == 0 { + return Err(io::Error::last_os_error()); + } + } + + Ok(()) + } + + /// Post a completion packet to this port. + pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> { + let result = unsafe { + PostQueuedCompletionStatus( + self.handle, + bytes_transferred + .try_into() + .expect("too many bytes transferred"), + id, + T::into_ptr(packet), + ) + }; + + if result == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } + + /// Wait for completion packets to arrive. + pub(super) fn wait( + &self, + packets: &mut Vec>, + timeout: Option, + ) -> io::Result { + // Drop the current packets. 
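+        // Dropping any leftover entries releases their packets through `OverlappedEntry`'s
+        // `Drop` implementation.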
+ packets.clear(); + + let mut count = MaybeUninit::::uninit(); + let timeout = timeout.map_or(INFINITE, dur2timeout); + + let result = unsafe { + GetQueuedCompletionStatusEx( + self.handle, + packets.as_mut_ptr() as _, + packets.capacity().try_into().expect("too many packets"), + count.as_mut_ptr(), + timeout, + 0, + ) + }; + + if result == 0 { + let io_error = io::Error::last_os_error(); + if io_error.kind() == io::ErrorKind::TimedOut { + Ok(0) + } else { + Err(io_error) + } + } else { + let count = unsafe { count.assume_init() }; + unsafe { + packets.set_len(count as _); + } + Ok(count as _) + } + } +} + +/// An `OVERLAPPED_ENTRY` resulting from an I/O completion port. +#[repr(transparent)] +pub(super) struct OverlappedEntry { + /// The underlying entry. + entry: OVERLAPPED_ENTRY, + + /// We own the status block. + _marker: PhantomData, +} + +impl fmt::Debug for OverlappedEntry { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("OverlappedEntry { .. }") + } +} + +impl OverlappedEntry { + /// Convert into the completion packet. + pub(super) fn into_packet(self) -> T { + let packet = unsafe { self.packet() }; + std::mem::forget(self); + packet + } + + unsafe fn packet(&self) -> T { + let packet = T::from_ptr(self.entry.lpOverlapped); + packet.get().unlock(); + packet + } +} + +impl Drop for OverlappedEntry { + fn drop(&mut self) { + drop(unsafe { self.packet() }); + } +} + +// Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs +fn dur2timeout(dur: Duration) -> u32 { + // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the + // timeouts in windows APIs are typically u32 milliseconds. To translate, we + // have two pieces to take care of: + // + // * Nanosecond precision is rounded up + // * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE + // (never time out). + dur.as_secs() + .checked_mul(1000) + .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000)) + .and_then(|ms| { + if dur.subsec_nanos() % 1_000_000 > 0 { + ms.checked_add(1) + } else { + Some(ms) + } + }) + .and_then(|x| u32::try_from(x).ok()) + .unwrap_or(INFINITE) +} + +struct CallOnDrop(F); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } +} diff --git a/src/lib.rs b/src/lib.rs index a6795c7..2e2b41f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,8 +113,8 @@ cfg_if! { mod poll; use poll as sys; } else if #[cfg(target_os = "windows")] { - mod wepoll; - use wepoll as sys; + mod iocp; + use iocp as sys; } else { compile_error!("polling does not support this target OS"); } diff --git a/src/wepoll.rs b/src/wepoll.rs deleted file mode 100644 index 6c65266..0000000 --- a/src/wepoll.rs +++ /dev/null @@ -1,254 +0,0 @@ -//! Bindings to wepoll (Windows). - -use std::convert::TryInto; -use std::io; -use std::os::raw::c_int; -use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket}; -use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::{Duration, Instant}; - -#[cfg(not(polling_no_io_safety))] -use std::os::windows::io::{AsHandle, BorrowedHandle}; - -use wepoll_ffi as we; - -use crate::{Event, PollMode}; - -/// Calls a wepoll function and results in `io::Result`. -macro_rules! wepoll { - ($fn:ident $args:tt) => {{ - let res = unsafe { we::$fn $args }; - if res == -1 { - Err(std::io::Error::last_os_error()) - } else { - Ok(res) - } - }}; -} - -/// Interface to wepoll. 
-#[derive(Debug)] -pub struct Poller { - handle: we::HANDLE, - notified: AtomicBool, -} - -unsafe impl Send for Poller {} -unsafe impl Sync for Poller {} - -impl Poller { - /// Creates a new poller. - pub fn new() -> io::Result { - let handle = unsafe { we::epoll_create1(0) }; - if handle.is_null() { - return Err(crate::unsupported_error( - format!( - "Failed to initialize Wepoll: {}\nThis usually only happens for old Windows or Wine.", - io::Error::last_os_error() - ) - )); - } - let notified = AtomicBool::new(false); - log::trace!("new: handle={:?}", handle); - Ok(Poller { handle, notified }) - } - - /// Whether this poller supports level-triggered events. - pub fn supports_level(&self) -> bool { - true - } - - /// Whether this poller supports edge-triggered events. - pub fn supports_edge(&self) -> bool { - false - } - - /// Adds a socket. - pub fn add(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> { - log::trace!("add: handle={:?}, sock={}, ev={:?}", self.handle, sock, ev); - self.ctl(we::EPOLL_CTL_ADD, sock, Some((ev, mode))) - } - - /// Modifies a socket. - pub fn modify(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> { - log::trace!( - "modify: handle={:?}, sock={}, ev={:?}", - self.handle, - sock, - ev - ); - self.ctl(we::EPOLL_CTL_MOD, sock, Some((ev, mode))) - } - - /// Deletes a socket. - pub fn delete(&self, sock: RawSocket) -> io::Result<()> { - log::trace!("remove: handle={:?}, sock={}", self.handle, sock); - self.ctl(we::EPOLL_CTL_DEL, sock, None) - } - - /// Waits for I/O events with an optional timeout. - /// - /// Returns the number of processed I/O events. - /// - /// If a notification occurs, this method will return but the notification event will not be - /// included in the `events` list nor contribute to the returned count. - pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { - log::trace!("wait: handle={:?}, timeout={:?}", self.handle, timeout); - let deadline = timeout.and_then(|t| Instant::now().checked_add(t)); - - loop { - // Convert the timeout to milliseconds. - let timeout_ms = match deadline.map(|d| d.saturating_duration_since(Instant::now())) { - None => -1, - Some(t) => { - // Round up to a whole millisecond. - let mut ms = t.as_millis().try_into().unwrap_or(std::u64::MAX); - if Duration::from_millis(ms) < t { - ms = ms.saturating_add(1); - } - ms.try_into().unwrap_or(std::i32::MAX) - } - }; - - // Wait for I/O events. - events.len = wepoll!(epoll_wait( - self.handle, - events.list.as_mut_ptr(), - events.list.len() as c_int, - timeout_ms, - ))? as usize; - log::trace!("new events: handle={:?}, len={}", self.handle, events.len); - - // Break if there was a notification or at least one event, or if deadline is reached. - if self.notified.swap(false, Ordering::SeqCst) || events.len > 0 || timeout_ms == 0 { - break; - } - } - - Ok(()) - } - - /// Sends a notification to wake up the current or next `wait()` call. - pub fn notify(&self) -> io::Result<()> { - log::trace!("notify: handle={:?}", self.handle); - - if self - .notified - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() - { - unsafe { - // This call errors if a notification has already been posted, but that's okay - we - // can just ignore the error. - // - // The original wepoll does not support notifications triggered this way, which is - // why wepoll-sys includes a small patch to support them. 
- windows_sys::Win32::System::IO::PostQueuedCompletionStatus( - self.handle as _, - 0, - 0, - ptr::null_mut(), - ); - } - } - Ok(()) - } - - /// Passes arguments to `epoll_ctl`. - fn ctl(&self, op: u32, sock: RawSocket, ev: Option<(Event, PollMode)>) -> io::Result<()> { - let mut ev = ev - .map(|(ev, mode)| { - let mut flags = match mode { - PollMode::Level => 0, - PollMode::Oneshot => we::EPOLLONESHOT, - PollMode::Edge => { - return Err(crate::unsupported_error( - "edge-triggered events are not supported with wepoll", - )); - } - }; - if ev.readable { - flags |= READ_FLAGS; - } - if ev.writable { - flags |= WRITE_FLAGS; - } - - Ok(we::epoll_event { - events: flags as u32, - data: we::epoll_data { - u64_: ev.key as u64, - }, - }) - }) - .transpose()?; - wepoll!(epoll_ctl( - self.handle, - op as c_int, - sock as we::SOCKET, - ev.as_mut() - .map(|ev| ev as *mut we::epoll_event) - .unwrap_or(ptr::null_mut()), - ))?; - Ok(()) - } -} - -impl AsRawHandle for Poller { - fn as_raw_handle(&self) -> RawHandle { - self.handle as RawHandle - } -} - -#[cfg(not(polling_no_io_safety))] -impl AsHandle for Poller { - fn as_handle(&self) -> BorrowedHandle<'_> { - // SAFETY: lifetime is bound by "self" - unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } - } -} - -impl Drop for Poller { - fn drop(&mut self) { - log::trace!("drop: handle={:?}", self.handle); - unsafe { - we::epoll_close(self.handle); - } - } -} - -/// Wepoll flags for all possible readability events. -const READ_FLAGS: u32 = we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI; - -/// Wepoll flags for all possible writability events. -const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR; - -/// A list of reported I/O events. -pub struct Events { - list: Box<[we::epoll_event; 1024]>, - len: usize, -} - -unsafe impl Send for Events {} - -impl Events { - /// Creates an empty list. - pub fn new() -> Events { - let ev = we::epoll_event { - events: 0, - data: we::epoll_data { u64_: 0 }, - }; - let list = Box::new([ev; 1024]); - Events { list, len: 0 } - } - - /// Iterates over I/O events. - pub fn iter(&self) -> impl Iterator + '_ { - self.list[..self.len].iter().map(|ev| Event { - key: unsafe { ev.data.u64_ } as usize, - readable: (ev.events & READ_FLAGS) != 0, - writable: (ev.events & WRITE_FLAGS) != 0, - }) - } -} diff --git a/tests/concurrent_modification.rs b/tests/concurrent_modification.rs index 0687ad5..7f31f05 100644 --- a/tests/concurrent_modification.rs +++ b/tests/concurrent_modification.rs @@ -43,7 +43,7 @@ fn concurrent_modify() -> io::Result<()> { Parallel::new() .add(|| { - poller.wait(&mut events, None)?; + poller.wait(&mut events, Some(Duration::from_secs(10)))?; Ok(()) }) .add(|| { diff --git a/tests/io.rs b/tests/io.rs new file mode 100644 index 0000000..ab0c8a8 --- /dev/null +++ b/tests/io.rs @@ -0,0 +1,38 @@ +use polling::{Event, Poller}; +use std::io::{self, Write}; +use std::net::{TcpListener, TcpStream}; +use std::time::Duration; + +#[test] +fn basic_io() { + let poller = Poller::new().unwrap(); + let (read, mut write) = tcp_pair().unwrap(); + poller.add(&read, Event::readable(1)).unwrap(); + + // Nothing should be available at first. + let mut events = vec![]; + assert_eq!( + poller + .wait(&mut events, Some(Duration::from_secs(0))) + .unwrap(), + 0 + ); + assert!(events.is_empty()); + + // After a write, the event should be available now. 
+ write.write_all(&[1]).unwrap(); + assert_eq!( + poller + .wait(&mut events, Some(Duration::from_secs(1))) + .unwrap(), + 1 + ); + assert_eq!(&*events, &[Event::readable(1)]); +} + +fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let a = TcpStream::connect(listener.local_addr()?)?; + let (b, _) = listener.accept()?; + Ok((a, b)) +} diff --git a/tests/precision.rs b/tests/precision.rs index d29bbce..de5d605 100644 --- a/tests/precision.rs +++ b/tests/precision.rs @@ -18,7 +18,7 @@ fn below_ms() -> io::Result<()> { let elapsed = now.elapsed(); assert_eq!(n, 0); - assert!(elapsed >= dur); + assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur); lowest = lowest.min(elapsed); } @@ -54,7 +54,7 @@ fn above_ms() -> io::Result<()> { let elapsed = now.elapsed(); assert_eq!(n, 0); - assert!(elapsed >= dur); + assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur); lowest = lowest.min(elapsed); } From 3195bc2ef3f37f04e69d589aac2c85e5cb8a69e7 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Tue, 21 Feb 2023 15:27:01 -0800 Subject: [PATCH 2/8] Fix documentation and add better errors --- Cargo.toml | 4 ++-- README.md | 4 ++-- src/iocp/mod.rs | 39 +++++++++++++++++++++++++++++++++++++-- src/lib.rs | 4 ++-- 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e6e57ad..3e86495 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,10 +7,10 @@ version = "2.5.2" authors = ["Stjepan Glavina "] edition = "2018" rust-version = "1.47" -description = "Portable interface to epoll, kqueue, event ports, and wepoll" +description = "Portable interface to epoll, kqueue, event ports, and IOCP" license = "Apache-2.0 OR MIT" repository = "https://github.com/smol-rs/polling" -keywords = ["mio", "epoll", "kqueue", "iocp", "wepoll"] +keywords = ["mio", "epoll", "kqueue", "iocp"] categories = ["asynchronous", "network-programming", "os"] exclude = ["/.*"] diff --git a/README.md b/README.md index 9d54a5c..67093bf 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ https://crates.io/crates/polling) [![Documentation](https://docs.rs/polling/badge.svg)]( https://docs.rs/polling) -Portable interface to epoll, kqueue, event ports, and wepoll. +Portable interface to epoll, kqueue, event ports, and IOCP. Supported platforms: - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android @@ -17,7 +17,7 @@ Supported platforms: DragonFly BSD - [event ports](https://illumos.org/man/port_create): illumos, Solaris - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems -- [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) +- [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+) Polling is done in oneshot mode, which means interest in I/O events needs to be reset after an event is delivered if we're interested in the next event of the same kind. diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 9762de5..277aa15 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -1,4 +1,29 @@ //! Bindings to Windows I/O Completion Ports. +//! +//! I/O Completion Ports is a completion-based API rather than a polling-based API, like +//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API. +//! +//! WinSock is powered by the Auxillary Function Driver (AFD) subsystem, which can be +//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not +//! 
available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is +//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready", +//! a completion packet is queued to an I/O completion port. +//! +//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API +//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL +//! operation is started and queued to the IOCP. To modify a currently registered device +//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted +//! with new parameters. Whn the POLL eventually completes, the packet is posted to the IOCP. +//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets +//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can +//! simply post a packet to the IOCP to wake it up. +//! +//! The main disadvantage of this strategy is that it relies on unstable Windows APIs. +//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable +//! AFD strategy, it is unlikely to be broken without plenty of advanced warning. +//! +//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar +//! AFD-based strategy for polling. mod afd; mod port; @@ -60,18 +85,28 @@ impl Poller { // Make sure AFD is able to be used. if let Err(e) = afd::NtdllImports::force_load() { return Err(crate::unsupported_error(format!( - "Failed to initialize I/O completion ports: {}\nThis usually only happens for old Windows or Wine.", + "Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.", e ))); } + + // Create a single AFD to test if we support it. + let afd = match Afd::new() { + Ok(afd) => afd, + Err(e) => return Err(crate::unsupported_error(format!( + "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.", + e, + ))), + }; let port = IoCompletionPort::new(0)?; + port.register(&afd, true)?; log::trace!("new: handle={:?}", &port); Ok(Poller { port, - afd: Mutex::new(Vec::new()), + afd: Mutex::new(vec![Arc::new(afd)]), sources: RwLock::new(HashMap::new()), pending_updates: ConcurrentQueue::bounded(1024), polling: AtomicBool::new(false), diff --git a/src/lib.rs b/src/lib.rs index 2e2b41f..6cbce52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! Portable interface to epoll, kqueue, event ports, and wepoll. +//! Portable interface to epoll, kqueue, event ports, and IOCP. //! //! Supported platforms: //! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android @@ -6,7 +6,7 @@ //! DragonFly BSD //! - [event ports](https://illumos.org/man/port_create): illumos, Solaris //! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems -//! - [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) +//! - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+) //! //! By default, polling is done in oneshot mode, which means interest in I/O events needs to //! 
be re-enabled after an event is delivered if we're interested in the next event of the same From 321a8d0efa83e58be3a3d2733a4decd1ea6744a9 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Tue, 21 Feb 2023 15:45:44 -0800 Subject: [PATCH 3/8] fmt --- src/iocp/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 277aa15..7324778 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -89,7 +89,7 @@ impl Poller { e ))); } - + // Create a single AFD to test if we support it. let afd = match Afd::new() { Ok(afd) => afd, From f1c12183950350883154b1267170f14e05398a76 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Wed, 22 Feb 2023 14:42:03 -0800 Subject: [PATCH 4/8] Review comments --- src/iocp/afd.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs index a8a2de3..e2f9c8b 100644 --- a/src/iocp/afd.rs +++ b/src/iocp/afd.rs @@ -67,6 +67,7 @@ impl AfdPollInfo { bitflags::bitflags! { #[derive(Default)] + #[repr(transparent)] pub(super) struct AfdPollMask: u32 { const RECEIVE = 0x001; const RECEIVE_EXPEDITED = 0x002; @@ -582,7 +583,7 @@ pub(super) fn base_socket(sock: RawSocket) -> io::Result { /// /// # Safety /// -/// The socket must be valid. +/// The `ioctl` parameter must be a valid I/O control that returns a valid socket. unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result { let mut out = MaybeUninit::::uninit(); let mut bytes = MaybeUninit::::uninit(); From 7f7ee5cdd07186c17f57b395b26310d693317cf4 Mon Sep 17 00:00:00 2001 From: jtnunley Date: Sun, 5 Mar 2023 14:11:49 -0800 Subject: [PATCH 5/8] Code review comments --- src/iocp/afd.rs | 4 ++-- src/iocp/mod.rs | 3 ++- src/iocp/port.rs | 10 ++++++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs index e2f9c8b..4a6600f 100644 --- a/src/iocp/afd.rs +++ b/src/iocp/afd.rs @@ -586,7 +586,7 @@ pub(super) fn base_socket(sock: RawSocket) -> io::Result { /// The `ioctl` parameter must be a valid I/O control that returns a valid socket. unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result { let mut out = MaybeUninit::::uninit(); - let mut bytes = MaybeUninit::::uninit(); + let mut bytes = 0u32; let result = WSAIoctl( sock as _, @@ -595,7 +595,7 @@ unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result 0, out.as_mut_ptr().cast(), size_of::() as u32, - bytes.as_mut_ptr(), + &mut bytes, ptr::null_mut(), None, ); diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 7324778..13c781e 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -262,7 +262,8 @@ impl Poller { let mut new_events = 0; // Indicate that we are now polling. - debug_assert!(!self.polling.swap(true, Ordering::SeqCst)); + let was_polling = self.polling.swap(true, Ordering::SeqCst); + debug_assert!(!was_polling); let guard = CallOnDrop(|| { debug_assert!(self.polling.swap(false, Ordering::SeqCst)); diff --git a/src/iocp/port.rs b/src/iocp/port.rs index ccc9575..d322c39 100644 --- a/src/iocp/port.rs +++ b/src/iocp/port.rs @@ -42,6 +42,10 @@ pub(super) unsafe trait CompletionHandle: Deref + Sized { type Completion: Completion; /// Get a pointer to the completion block. + /// + /// The pointer is pinned since the underlying object should not be moved + /// after creation. This prevents it from being invalidated while it's + /// used in an overlapped operation. fn get(&self) -> Pin<&Self::Completion>; /// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`. 
@@ -272,6 +276,12 @@ impl OverlappedEntry {
         packet
     }
 
+    /// Get the packet reference that this entry refers to.
+    ///
+    /// # Safety
+    ///
+    /// This function should only be called once, since it moves
+    /// out the `T` from the `OVERLAPPED_ENTRY`.
     unsafe fn packet(&self) -> T {
         let packet = T::from_ptr(self.entry.lpOverlapped);
         packet.get().unlock();

From c3ab16109ce2a86aaafe1717c56deb3c83a5f9de Mon Sep 17 00:00:00 2001
From: jtnunley
Date: Sun, 5 Mar 2023 14:16:54 -0800
Subject: [PATCH 6/8] fmt

---
 src/iocp/port.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/iocp/port.rs b/src/iocp/port.rs
index d322c39..3feae07 100644
--- a/src/iocp/port.rs
+++ b/src/iocp/port.rs
@@ -42,7 +42,7 @@ pub(super) unsafe trait CompletionHandle: Deref + Sized {
     type Completion: Completion;
 
     /// Get a pointer to the completion block.
-    /// 
+    ///
     /// The pointer is pinned since the underlying object should not be moved
     /// after creation. This prevents it from being invalidated while it's
     /// used in an overlapped operation.
@@ -277,9 +277,9 @@ impl OverlappedEntry {
     }
 
     /// Get the packet reference that this entry refers to.
-    /// 
+    ///
     /// # Safety
-    /// 
+    ///
     /// This function should only be called once, since it moves
     /// out the `T` from the `OVERLAPPED_ENTRY`.
     unsafe fn packet(&self) -> T {

From 3a41a8e2a6f57e04497ea07d4f3835dca1dd9f91 Mon Sep 17 00:00:00 2001
From: jtnunley
Date: Sun, 5 Mar 2023 15:47:06 -0800
Subject: [PATCH 7/8] Code review #2

---
 src/iocp/mod.rs | 135 ++++++++++++++++++++++++++----------------------
 1 file changed, 74 insertions(+), 61 deletions(-)

diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs
index 13c781e..4fafc01 100644
--- a/src/iocp/mod.rs
+++ b/src/iocp/mod.rs
@@ -45,12 +45,19 @@ use std::marker::PhantomPinned;
 use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
 use std::pin::Pin;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::{Arc, Mutex, MutexGuard, RwLock};
+use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
 use std::time::{Duration, Instant};
 
 #[cfg(not(polling_no_io_safety))]
 use std::os::windows::io::{AsHandle, BorrowedHandle};
 
+/// Macro to lock and ignore lock poisoning.
+macro_rules! lock {
+    ($lock_result:expr) => {{
+        ($lock_result).unwrap_or_else(|e| e.into_inner())
+    }};
+}
+
 /// Interface to I/O completion ports.
 #[derive(Debug)]
 pub(super) struct Poller {
@@ -58,7 +65,10 @@ pub(super) struct Poller {
     port: IoCompletionPort<Packet>,
 
     /// List of currently active AFD instances.
-    afd: Mutex<Vec<Arc<Afd<Packet>>>>,
+    ///
+    /// Weak references are kept here so that the AFD handle is automatically dropped
+    /// when the last associated socket is dropped.
+    afd: Mutex<Vec<Weak<Afd<Packet>>>>,
 
     /// The state of the sources registered with this poller.
     sources: RwLock<HashMap<RawSocket, Packet>>,
@@ -90,23 +100,19 @@ impl Poller {
             )));
         }
 
-        // Create a single AFD to test if we support it.
-        let afd = match Afd::new() {
-            Ok(afd) => afd,
-            Err(e) => return Err(crate::unsupported_error(format!(
-                "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
-                e,
-            ))),
-        };
+        // Create and destroy a single AFD to test if we support it.
+        Afd::<Packet>::new().map_err(|e| crate::unsupported_error(format!(
+            "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
+            e,
+        )))?;
 
         let port = IoCompletionPort::new(0)?;
-        port.register(&afd, true)?;
 
         log::trace!("new: handle={:?}", &port);
 
         Ok(Poller {
             port,
-            afd: Mutex::new(vec![Arc::new(afd)]),
+            afd: Mutex::new(vec![]),
             sources: RwLock::new(HashMap::new()),
             pending_updates: ConcurrentQueue::bounded(1024),
             polling: AtomicBool::new(false),
@@ -168,7 +174,7 @@ impl Poller {
 
         // Keep track of the source in the poller.
         {
-            let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner());
+            let mut sources = lock!(self.sources.write());
 
             match sources.entry(socket) {
                 Entry::Vacant(v) => {
@@ -209,14 +215,12 @@ impl Poller {
 
         // Get a reference to the source.
         let source = {
-            let sources = self.sources.read().unwrap_or_else(|e| e.into_inner());
+            let sources = lock!(self.sources.read());
 
-            match sources.get(&socket) {
-                Some(s) => s.clone(),
-                None => {
-                    return Err(io::Error::from(io::ErrorKind::NotFound));
-                }
-            }
+            sources
+                .get(&socket)
+                .cloned()
+                .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
         };
 
         // Set the new event.
@@ -233,12 +237,12 @@ impl Poller {
 
         // Get a reference to the source.
         let source = {
-            let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner());
+            let mut sources = lock!(self.sources.write());
 
             match sources.remove(&socket) {
                 Some(s) => s,
                 None => {
-                    // Just return.
+                    // If the source has already been removed, then we can just return.
                     return Ok(());
                 }
             }
@@ -254,7 +258,7 @@ impl Poller {
         log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout);
 
         let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
-        let mut packets = self.packets.lock().unwrap_or_else(|e| e.into_inner());
+        let mut packets = lock!(self.packets.lock());
         let mut notified = false;
 
         events.packets.clear();
@@ -266,7 +270,8 @@ impl Poller {
         debug_assert!(!was_polling);
 
         let guard = CallOnDrop(|| {
-            debug_assert!(self.polling.swap(false, Ordering::SeqCst));
+            let was_polling = self.polling.swap(false, Ordering::SeqCst);
+            debug_assert!(was_polling);
         });
 
         // Process every entry in the queue before we start polling.
@@ -347,38 +352,56 @@ impl Poller {
         };
 
         // Only drain the queue's capacity, since this could in theory run forever.
-        for _ in 0..max {
-            if let Ok(packet) = self.pending_updates.pop() {
-                packet.update()?;
-            } else {
-                return Ok(());
-            }
-        }
-
-        Ok(())
+        core::iter::from_fn(|| self.pending_updates.pop().ok())
+            .take(max)
+            .try_for_each(|packet| packet.update())
     }
 
     /// Get a handle to the AFD reference.
     fn afd_handle(&self) -> io::Result<Arc<Afd<Packet>>> {
         const AFD_MAX_SIZE: usize = 32;
 
-        // See if there are any existing AFD instances that we can use.
-        let mut afd_handles = self.afd.lock().unwrap_or_else(|e| e.into_inner());
-        if let Some(handle) = afd_handles.iter().find(|h| {
-            let ref_count = Arc::strong_count(h).saturating_sub(1);
-            ref_count < AFD_MAX_SIZE
-        }) {
-            return Ok(handle.clone());
+        // Crawl the list and see if there are any existing AFD instances that we can use.
+        // Remove any unused AFD pointers.
+        let mut afd_handles = lock!(self.afd.lock());
+        let mut i = 0;
+        while i < afd_handles.len() {
+            // Get the reference count of the AFD instance.
+            let refcount = Weak::strong_count(&afd_handles[i]);
+
+            match refcount {
+                0 => {
+                    // Prune the AFD pointer if it has no references.
+                    afd_handles.swap_remove(i);
+                }
+
+                refcount if refcount >= AFD_MAX_SIZE => {
+                    // Skip this one, since it is already at the maximum size.
+                    i += 1;
+                }
+
+                _ => {
+                    // We can use this AFD instance.
+                    match afd_handles[i].upgrade() {
+                        Some(afd) => return Ok(afd),
+                        None => {
+                            // The last socket dropped the AFD before we could acquire it.
+                            // Prune the AFD pointer and continue.
+                            afd_handles.swap_remove(i);
+                        }
+                    }
+                }
+            }
         }
 
-        // Create a new AFD instance.
+        // No available handles, create a new AFD instance.
         let afd = Arc::new(Afd::new()?);
 
         // Register the AFD instance with the I/O completion port.
         self.port.register(&*afd, true)?;
 
-        // Insert a copy of the AFD instance into the list.
-        afd_handles.push(afd.clone());
+        // Insert a weak pointer to the AFD instance into the list.
+        afd_handles.push(Arc::downgrade(&afd));
 
         Ok(afd)
     }
@@ -573,7 +596,7 @@ impl PacketUnwrapped {
             }
         };
 
-        let mut socket_state = socket.lock().unwrap_or_else(|e| e.into_inner());
+        let mut socket_state = lock!(socket.lock());
         let mut event = Event::none(socket_state.interest.key);
 
         // Put ourselves into the idle state.
@@ -609,11 +632,9 @@ impl PacketUnwrapped {
 
                 // If we closed the socket, remove it from being polled.
                 if events.contains(AfdPollMask::LOCAL_CLOSE) {
-                    let source = {
-                        let mut sources =
-                            poller.sources.write().unwrap_or_else(|e| e.into_inner());
-                        sources.remove(&socket_state.socket).unwrap()
-                    };
+                    let source = lock!(poller.sources.write())
+                        .remove(&socket_state.socket)
+                        .unwrap();
                     return source.begin_delete().map(|()| FeedEventResult::NoEvent);
                 }
@@ -627,15 +648,8 @@ impl PacketUnwrapped {
                 }
 
                 // Filter out events that the user didn't ask for.
-                {
-                    if !socket_state.interest.readable {
-                        event.readable = false;
-                    }
-
-                    if !socket_state.interest.writable {
-                        event.writable = false;
-                    }
-                }
+                event.readable &= socket_state.interest.readable;
+                event.writable &= socket_state.interest.writable;
 
                 // If this event doesn't have anything that interests us, don't return or
                 // update the oneshot state.
@@ -665,7 +679,7 @@ impl PacketUnwrapped {
         let mut socket = self
             .as_ref()
            .socket_state()
-            .expect("can't delete notification packet");
+            .expect("can't delete packet that doesn't belong to a socket");
 
         if !socket.waiting_on_delete {
             socket.waiting_on_delete = true;
@@ -701,8 +715,7 @@ impl PacketUnwrapped {
             PacketInnerProj::Socket { socket, .. } => socket,
         };
 
-        let guard = state.lock().unwrap_or_else(|e| e.into_inner());
-        Some(guard)
+        Some(lock!(state.lock()))
     }
 }

From 2b75eead9d486d1208e65d9a259b19150b228fc6 Mon Sep 17 00:00:00 2001
From: jtnunley
Date: Sun, 5 Mar 2023 16:04:02 -0800
Subject: [PATCH 8/8] Hygienic macro fix

---
 src/iocp/mod.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs
index 4fafc01..cc618b1 100644
--- a/src/iocp/mod.rs
+++ b/src/iocp/mod.rs
@@ -54,7 +54,7 @@ use std::os::windows::io::{AsHandle, BorrowedHandle};
 /// Macro to lock and ignore lock poisoning.
 macro_rules! lock {
     ($lock_result:expr) => {{
-        ($lock_result).unwrap_or_else(|e| e.into_inner())
+        $lock_result.unwrap_or_else(|e| e.into_inner())
     }};
 }
 
@@ -438,7 +438,7 @@ impl Events {
 
     /// Iterate over I/O events.
     pub(super) fn iter(&self) -> impl Iterator<Item = Event> + '_ {
-        self.packets.iter().cloned()
+        self.packets.iter().copied()
     }
 }
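
Background sketch (not part of the patch series): the polling-to-IOCP translation described in the module docs of patch 2 exists to serve the crate's ordinary oneshot API. A minimal caller-side flow looks like the sketch below; it only uses the public `Poller`/`Event` API, and the address and key are arbitrary placeholder values.

    use polling::{Event, Poller};
    use std::net::TcpListener;

    fn main() -> std::io::Result<()> {
        // Any non-blocking socket works; the address and key are arbitrary.
        let socket = TcpListener::bind("127.0.0.1:8000")?;
        socket.set_nonblocking(true)?;
        let key = 7;

        let poller = Poller::new()?;
        // On Windows this backend starts an IOCTL_AFD_POLL operation here.
        poller.add(&socket, Event::readable(key))?;

        let mut events = Vec::new();
        loop {
            events.clear();
            poller.wait(&mut events, None)?;

            for ev in &events {
                if ev.key == key {
                    let (_client, _addr) = socket.accept()?;
                    // Oneshot mode: interest is consumed by the event and must
                    // be re-armed, which restarts the underlying AFD poll.
                    poller.modify(&socket, Event::readable(key))?;
                }
            }
        }
    }

Each `wait` corresponds to draining completion packets from the IOCP with `GetQueuedCompletionStatusEx`, and each `modify` restarts the IOCTL_AFD_POLL operation for that socket.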
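The `lock!` macro added in patch 7 and made hygienic in patch 8 simply ignores lock poisoning: a poisoned lock still hands back its guard instead of panicking a second time. A standalone sketch of the same pattern, outside the crate, looks like this:

    use std::sync::Mutex;

    // Same shape as the `lock!` macro in src/iocp/mod.rs.
    macro_rules! lock {
        ($lock_result:expr) => {{
            $lock_result.unwrap_or_else(|e| e.into_inner())
        }};
    }

    fn main() {
        let sources = Mutex::new(vec![1u32, 2, 3]);

        // Equivalent to `sources.lock().unwrap()` on the happy path.
        lock!(sources.lock()).push(4);

        // If another thread had panicked while holding the lock, the guard
        // would be recovered from the `PoisonError` rather than propagated.
        let guard = lock!(sources.lock());
        assert_eq!(guard.len(), 4);
    }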
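The rewritten `afd_handle` in patch 7 is essentially a small cache of weak pointers: dead entries are pruned, live entries are reused until they reach `AFD_MAX_SIZE` strong references, and otherwise a fresh AFD handle is created and only a weak pointer is retained. The sketch below is a simplified, socket-free model of that strategy; the names `HandleCache`, `Handle`, and `MAX_PER_HANDLE` are illustrative and not part of the crate.

    use std::sync::{Arc, Mutex, Weak};

    // Illustrative stand-in for an AFD handle; the real code stores `Afd<Packet>`.
    struct Handle(u32);

    const MAX_PER_HANDLE: usize = 32;

    struct HandleCache {
        slots: Mutex<Vec<Weak<Handle>>>,
        next_id: Mutex<u32>,
    }

    impl HandleCache {
        fn acquire(&self) -> Arc<Handle> {
            let mut slots = self.slots.lock().unwrap();

            let mut i = 0;
            while i < slots.len() {
                let refcount = Weak::strong_count(&slots[i]);
                match refcount {
                    // Nothing points at this handle any more; drop the stale slot.
                    0 => {
                        slots.swap_remove(i);
                    }
                    // Handle is at capacity; look at the next one.
                    n if n >= MAX_PER_HANDLE => i += 1,
                    // Handle has room; try to hand it out.
                    _ => match slots[i].upgrade() {
                        Some(handle) => return handle,
                        // Last strong reference vanished between the count and
                        // the upgrade; prune and keep scanning.
                        None => {
                            slots.swap_remove(i);
                        }
                    },
                }
            }

            // No usable handle; create one and keep only a weak pointer so it
            // is freed as soon as the last user goes away.
            let mut next_id = self.next_id.lock().unwrap();
            let handle = Arc::new(Handle(*next_id));
            *next_id += 1;
            slots.push(Arc::downgrade(&handle));
            handle
        }
    }

    fn main() {
        let cache = HandleCache {
            slots: Mutex::new(Vec::new()),
            next_id: Mutex::new(0),
        };

        let a = cache.acquire();
        let b = cache.acquire(); // reuses the same underlying handle
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(a.0, 0);
    }

Keeping only weak pointers in the cache is what lets an AFD device handle close automatically once the last socket registered against it is dropped, instead of accumulating for the lifetime of the `Poller`.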