From 7bfcc53a3e1d8c6f74ee735c7fc1412c8dea75ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 19:46:57 +0800 Subject: [PATCH] feat: split sources and waitables --- src/iocp/psn/mod.rs | 195 +++++++++++++++++--------------------------- 1 file changed, 74 insertions(+), 121 deletions(-) diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index 023ec76..f0adbf6 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -23,7 +23,7 @@ use std::os::windows::io::{ RawHandle, RawSocket, }; use std::ptr::null_mut; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; use wait::WaitCompletionPacket; @@ -44,27 +44,35 @@ use windows_sys::Win32::System::IO::{ use super::dur2timeout; use crate::{Event, PollMode, NOTIFY_KEY}; +/// Macro to lock and ignore lock poisoning. +macro_rules! lock { + ($lock_result:expr) => {{ + $lock_result.unwrap_or_else(|e| e.into_inner()) + }}; +} + /// Interface to kqueue. #[derive(Debug)] pub struct Poller { /// The I/O completion port. port: Arc, - /// Attribute map. - sources: RwLock>, + + /// The state of the sources registered with this poller. + /// + /// Each source is keyed by its raw socket ID. + sources: RwLock>, + + /// The state of the waitable handles registered with this poller. + waitables: Mutex>, } -/// Attributes of added sources. +/// A waitable object with key and [`WaitCompletionPacket`]. +/// +/// [`WaitCompletionPacket`]: wait::WaitCompletionPacket #[derive(Debug)] -pub(crate) enum SourceAttr { - /// A socket with key. - Socket { key: usize }, - /// A waitable object with key and [`WaitCompletionPacket`]. - /// - /// [`WaitCompletionPacket`]: wait::WaitCompletionPacket - Waitable { - key: usize, - packet: wait::WaitCompletionPacket, - }, +struct WaitableAttr { + key: usize, + packet: wait::WaitCompletionPacket, } impl Poller { @@ -80,6 +88,7 @@ impl Poller { Ok(Poller { port, sources: RwLock::default(), + waitables: Mutex::default(), }) } @@ -107,11 +116,11 @@ impl Poller { ); let _enter = span.enter(); - self.add_source( - socket as _, - SourceAttr::Socket { key: interest.key }, - |_| Ok(()), - )?; + let mut sources = lock!(self.sources.write()); + if sources.contains_key(&socket) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + sources.insert(socket, interest.key); let info = create_registration(socket, interest, mode, true); self.update_source(info) @@ -134,7 +143,9 @@ impl Poller { let socket = socket.as_raw_socket(); - self.has_socket(socket as _)?; + lock!(self.sources.read()) + .get(&socket) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; let info = create_registration(socket, interest, mode, true); unsafe { self.update_source(info) } @@ -151,12 +162,11 @@ impl Poller { let socket = socket.as_raw_socket(); - if let SourceAttr::Socket { key } = self.remove_source(socket as _)? { - let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); - unsafe { self.update_source(info) } - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } + let key = lock!(self.sources.write()) + .remove(&socket) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); + unsafe { self.update_source(info) } } /// Add a new waitable to the poller. @@ -182,23 +192,20 @@ impl Poller { let key = interest.key; - let packet = wait::WaitCompletionPacket::new()?; - self.add_source( - handle as _, - SourceAttr::Waitable { key, packet }, - |source| { - if let SourceAttr::Waitable { key, packet } = source { - packet.associate( - self.port.as_raw_handle(), - handle, - *key, - interest_to_events(&interest) as _, - ) - } else { - unreachable!() - } - }, - ) + let mut waitables = lock!(self.waitables.lock()); + if waitables.contains_key(&handle) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + + let mut packet = wait::WaitCompletionPacket::new()?; + packet.associate( + self.port.as_raw_handle(), + handle, + key, + interest_to_events(&interest) as _, + )?; + waitables.insert(handle, WaitableAttr { key, packet }); + Ok(()) } /// Update a waitable in the poller. @@ -222,91 +229,34 @@ impl Poller { )); } - self.has_waitable(waitable as _, |key, packet| { - let cancelled = packet.cancel()?; - if !cancelled { - // The packet could not be reused, create a new one. - *packet = WaitCompletionPacket::new()?; - } - packet.associate( - self.port.as_raw_handle(), - waitable, - key, - interest_to_events(&interest) as _, - ) - }) + let mut waitables = lock!(self.waitables.lock()); + let WaitableAttr { key, packet } = waitables + .get_mut(&waitable) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + + let cancelled = packet.cancel()?; + if !cancelled { + // The packet could not be reused, create a new one. + *packet = WaitCompletionPacket::new()?; + } + packet.associate( + self.port.as_raw_handle(), + waitable, + *key, + interest_to_events(&interest) as _, + ) } /// Delete a waitable from the poller. pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); - if let SourceAttr::Waitable { mut packet, .. } = self.remove_source(waitable as _)? { - packet.cancel()?; - Ok(()) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } + let WaitableAttr { mut packet, .. } = lock!(self.waitables.lock()) + .remove(&waitable) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; - /// Add a source to the sources set. - #[inline] - pub(crate) fn add_source( - &self, - handle: usize, - source: SourceAttr, - handler: impl FnOnce(&mut SourceAttr) -> io::Result<()>, - ) -> io::Result<()> { - let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner()); - if sources.contains_key(&handle) { - return Err(io::Error::from(io::ErrorKind::AlreadyExists)); - } - let source = sources.entry(handle).or_insert(source); - handler(source) - } - - /// Tell if a socket is currently inside the set. - #[inline] - pub(crate) fn has_socket(&self, handle: usize) -> io::Result { - if let Some(SourceAttr::Socket { key }) = self - .sources - .read() - .unwrap_or_else(|e| e.into_inner()) - .get(&handle) - { - Ok(*key) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } - - /// Tell if a waitable is currently inside the set. - #[inline] - pub(crate) fn has_waitable( - &self, - handle: usize, - handler: impl FnOnce(usize, &mut WaitCompletionPacket) -> io::Result<()>, - ) -> io::Result<()> { - if let Some(SourceAttr::Waitable { key, packet }) = self - .sources - .write() - .unwrap_or_else(|e| e.into_inner()) - .get_mut(&handle) - { - handler(*key, packet) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } - - /// Remove a source from the sources set. - #[inline] - pub(crate) fn remove_source(&self, handle: usize) -> io::Result { - self.sources - .write() - .unwrap_or_else(|e| e.into_inner()) - .remove(&handle) - .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) + packet.cancel()?; + Ok(()) } /// Add or modify the registration. @@ -415,6 +365,9 @@ impl AsHandle for Poller { } } +unsafe impl Send for Poller {} +unsafe impl Sync for Poller {} + /// A list of reported I/O events. pub struct Events { list: Vec,