Skip to content

Commit

Permalink
feat: split sources and waitables
Browse files Browse the repository at this point in the history
  • Loading branch information
Berrysoft committed Jun 11, 2024
1 parent 6104c0d commit 7bfcc53
Showing 1 changed file with 74 additions and 121 deletions.
195 changes: 74 additions & 121 deletions src/iocp/psn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::os::windows::io::{
RawHandle, RawSocket,
};
use std::ptr::null_mut;
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;

use wait::WaitCompletionPacket;
Expand All @@ -44,27 +44,35 @@ use windows_sys::Win32::System::IO::{
use super::dur2timeout;
use crate::{Event, PollMode, NOTIFY_KEY};

/// Macro to lock and ignore lock poisoning.
macro_rules! lock {
($lock_result:expr) => {{
$lock_result.unwrap_or_else(|e| e.into_inner())
}};
}

/// Interface to kqueue.
#[derive(Debug)]
pub struct Poller {
/// The I/O completion port.
port: Arc<OwnedHandle>,
/// Attribute map.
sources: RwLock<HashMap<usize, SourceAttr>>,

/// The state of the sources registered with this poller.
///
/// Each source is keyed by its raw socket ID.
sources: RwLock<HashMap<RawSocket, usize>>,

/// The state of the waitable handles registered with this poller.
waitables: Mutex<HashMap<RawHandle, WaitableAttr>>,
}

/// Attributes of added sources.
/// A waitable object with key and [`WaitCompletionPacket`].
///
/// [`WaitCompletionPacket`]: wait::WaitCompletionPacket
#[derive(Debug)]
pub(crate) enum SourceAttr {
/// A socket with key.
Socket { key: usize },
/// A waitable object with key and [`WaitCompletionPacket`].
///
/// [`WaitCompletionPacket`]: wait::WaitCompletionPacket
Waitable {
key: usize,
packet: wait::WaitCompletionPacket,
},
struct WaitableAttr {
key: usize,
packet: wait::WaitCompletionPacket,
}

impl Poller {
Expand All @@ -80,6 +88,7 @@ impl Poller {
Ok(Poller {
port,
sources: RwLock::default(),
waitables: Mutex::default(),
})
}

Expand Down Expand Up @@ -107,11 +116,11 @@ impl Poller {
);
let _enter = span.enter();

self.add_source(
socket as _,
SourceAttr::Socket { key: interest.key },
|_| Ok(()),
)?;
let mut sources = lock!(self.sources.write());
if sources.contains_key(&socket) {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}
sources.insert(socket, interest.key);

let info = create_registration(socket, interest, mode, true);
self.update_source(info)
Expand All @@ -134,7 +143,9 @@ impl Poller {

let socket = socket.as_raw_socket();

self.has_socket(socket as _)?;
lock!(self.sources.read())
.get(&socket)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;

let info = create_registration(socket, interest, mode, true);
unsafe { self.update_source(info) }
Expand All @@ -151,12 +162,11 @@ impl Poller {

let socket = socket.as_raw_socket();

if let SourceAttr::Socket { key } = self.remove_source(socket as _)? {
let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false);
unsafe { self.update_source(info) }
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
let key = lock!(self.sources.write())
.remove(&socket)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false);
unsafe { self.update_source(info) }
}

/// Add a new waitable to the poller.
Expand All @@ -182,23 +192,20 @@ impl Poller {

let key = interest.key;

let packet = wait::WaitCompletionPacket::new()?;
self.add_source(
handle as _,
SourceAttr::Waitable { key, packet },
|source| {
if let SourceAttr::Waitable { key, packet } = source {
packet.associate(
self.port.as_raw_handle(),
handle,
*key,
interest_to_events(&interest) as _,
)
} else {
unreachable!()
}
},
)
let mut waitables = lock!(self.waitables.lock());
if waitables.contains_key(&handle) {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}

let mut packet = wait::WaitCompletionPacket::new()?;
packet.associate(
self.port.as_raw_handle(),
handle,
key,
interest_to_events(&interest) as _,
)?;
waitables.insert(handle, WaitableAttr { key, packet });
Ok(())
}

/// Update a waitable in the poller.
Expand All @@ -222,91 +229,34 @@ impl Poller {
));
}

self.has_waitable(waitable as _, |key, packet| {
let cancelled = packet.cancel()?;
if !cancelled {
// The packet could not be reused, create a new one.
*packet = WaitCompletionPacket::new()?;
}
packet.associate(
self.port.as_raw_handle(),
waitable,
key,
interest_to_events(&interest) as _,
)
})
let mut waitables = lock!(self.waitables.lock());
let WaitableAttr { key, packet } = waitables
.get_mut(&waitable)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;

let cancelled = packet.cancel()?;
if !cancelled {
// The packet could not be reused, create a new one.
*packet = WaitCompletionPacket::new()?;
}
packet.associate(
self.port.as_raw_handle(),
waitable,
*key,
interest_to_events(&interest) as _,
)
}

/// Delete a waitable from the poller.
pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> {
tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable);

if let SourceAttr::Waitable { mut packet, .. } = self.remove_source(waitable as _)? {
packet.cancel()?;
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}
let WaitableAttr { mut packet, .. } = lock!(self.waitables.lock())
.remove(&waitable)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;

/// Add a source to the sources set.
#[inline]
pub(crate) fn add_source(
&self,
handle: usize,
source: SourceAttr,
handler: impl FnOnce(&mut SourceAttr) -> io::Result<()>,
) -> io::Result<()> {
let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner());
if sources.contains_key(&handle) {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}
let source = sources.entry(handle).or_insert(source);
handler(source)
}

/// Tell if a socket is currently inside the set.
#[inline]
pub(crate) fn has_socket(&self, handle: usize) -> io::Result<usize> {
if let Some(SourceAttr::Socket { key }) = self
.sources
.read()
.unwrap_or_else(|e| e.into_inner())
.get(&handle)
{
Ok(*key)
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}

/// Tell if a waitable is currently inside the set.
#[inline]
pub(crate) fn has_waitable(
&self,
handle: usize,
handler: impl FnOnce(usize, &mut WaitCompletionPacket) -> io::Result<()>,
) -> io::Result<()> {
if let Some(SourceAttr::Waitable { key, packet }) = self
.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.get_mut(&handle)
{
handler(*key, packet)
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}

/// Remove a source from the sources set.
#[inline]
pub(crate) fn remove_source(&self, handle: usize) -> io::Result<SourceAttr> {
self.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&handle)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))
packet.cancel()?;
Ok(())
}

/// Add or modify the registration.
Expand Down Expand Up @@ -415,6 +365,9 @@ impl AsHandle for Poller {
}
}

unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}

/// A list of reported I/O events.
pub struct Events {
list: Vec<OVERLAPPED_ENTRY>,
Expand Down

0 comments on commit 7bfcc53

Please sign in to comment.