diff --git a/monoio/src/driver/legacy/iocp/afd.rs b/monoio/src/driver/legacy/iocp/afd.rs new file mode 100644 index 00000000..a4681a8f --- /dev/null +++ b/monoio/src/driver/legacy/iocp/afd.rs @@ -0,0 +1,198 @@ +use std::{ + ffi::c_void, + fs::File, + os::windows::prelude::{AsRawHandle, FromRawHandle, RawHandle}, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use windows_sys::Win32::{ + Foundation::{ + RtlNtStatusToDosError, HANDLE, INVALID_HANDLE_VALUE, NTSTATUS, STATUS_NOT_FOUND, + STATUS_PENDING, STATUS_SUCCESS, UNICODE_STRING, + }, + Storage::FileSystem::{ + NtCreateFile, SetFileCompletionNotificationModes, FILE_OPEN, FILE_SHARE_READ, + FILE_SHARE_WRITE, SYNCHRONIZE, + }, + System::WindowsProgramming::{ + NtDeviceIoControlFile, FILE_SKIP_SET_EVENT_ON_HANDLE, IO_STATUS_BLOCK, IO_STATUS_BLOCK_0, + OBJECT_ATTRIBUTES, + }, +}; + +use super::CompletionPort; + +#[link(name = "ntdll")] +extern "system" { + /// See + /// + /// This is an undocumented API and as such not part of + /// from which `windows-sys` is generated, and also unlikely to be added, so + /// we manually declare it here + fn NtCancelIoFileEx( + FileHandle: HANDLE, + IoRequestToCancel: *mut IO_STATUS_BLOCK, + IoStatusBlock: *mut IO_STATUS_BLOCK, + ) -> NTSTATUS; +} + +static NEXT_TOKEN: AtomicUsize = AtomicUsize::new(0); + +macro_rules! s { + ($($id:expr)+) => { + &[$($id as u16),+] + } +} + +pub const POLL_RECEIVE: u32 = 0b0_0000_0001; +pub const POLL_RECEIVE_EXPEDITED: u32 = 0b0_0000_0010; +pub const POLL_SEND: u32 = 0b0_0000_0100; +pub const POLL_DISCONNECT: u32 = 0b0_0000_1000; +pub const POLL_ABORT: u32 = 0b0_0001_0000; +pub const POLL_LOCAL_CLOSE: u32 = 0b0_0010_0000; +// Not used as it indicated in each event where a connection is connected, not +// just the first time a connection is established. +// Also see https://github.com/piscisaureus/wepoll/commit/8b7b340610f88af3d83f40fb728e7b850b090ece. 
+pub const POLL_CONNECT: u32 = 0b0_0100_0000; +pub const POLL_ACCEPT: u32 = 0b0_1000_0000; +pub const POLL_CONNECT_FAIL: u32 = 0b1_0000_0000; + +pub const KNOWN_EVENTS: u32 = POLL_RECEIVE + | POLL_RECEIVE_EXPEDITED + | POLL_SEND + | POLL_DISCONNECT + | POLL_ABORT + | POLL_LOCAL_CLOSE + | POLL_ACCEPT + | POLL_CONNECT_FAIL; + +#[repr(C)] +pub struct AfdPollHandleInfo { + pub handle: HANDLE, + pub events: u32, + pub status: NTSTATUS, +} + +#[repr(C)] +pub struct AfdPollInfo { + pub timeout: i64, + pub number_of_handles: u32, + pub exclusive: u32, + pub handles: [AfdPollHandleInfo; 1], +} + +pub struct Afd { + file: File, +} + +impl Afd { + pub fn new(cp: &CompletionPort) -> std::io::Result { + const AFD_NAME: &[u16] = s!['\\' 'D' 'e' 'v' 'i' 'c' 'e' '\\' 'A' 'f' 'd' '\\' 'I' 'o']; + let mut device_name = UNICODE_STRING { + Length: std::mem::size_of_val(AFD_NAME) as u16, + MaximumLength: std::mem::size_of_val(AFD_NAME) as u16, + Buffer: AFD_NAME.as_ptr() as *mut u16, + }; + let mut device_attributes = OBJECT_ATTRIBUTES { + Length: std::mem::size_of::() as u32, + RootDirectory: 0, + ObjectName: &mut device_name, + Attributes: 0, + SecurityDescriptor: std::ptr::null_mut(), + SecurityQualityOfService: std::ptr::null_mut(), + }; + let mut handle = INVALID_HANDLE_VALUE; + let mut iosb = unsafe { std::mem::zeroed::() }; + let result = unsafe { + NtCreateFile( + &mut handle, + SYNCHRONIZE, + &mut device_attributes, + &mut iosb, + std::ptr::null_mut(), + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN, + 0, + std::ptr::null_mut(), + 0, + ) + }; + + if result != STATUS_SUCCESS { + let error = unsafe { RtlNtStatusToDosError(result) }; + return Err(std::io::Error::from_raw_os_error(error as i32)); + } + + let file = unsafe { File::from_raw_handle(handle as RawHandle) }; + // Increment by 2 to reserve space for other types of handles. + // Non-AFD types (currently only NamedPipe), use odd numbered + // tokens. 
This allows the selector to differentiate between them + // and dispatch events accordingly. + let token = NEXT_TOKEN.fetch_add(2, Ordering::Relaxed) + 2; + cp.add_handle(token, file.as_raw_handle() as HANDLE)?; + let result = unsafe { + SetFileCompletionNotificationModes( + handle, + FILE_SKIP_SET_EVENT_ON_HANDLE as u8, // This is just 2, so fits in u8 + ) + }; + + if result == 0 { + Err(std::io::Error::last_os_error()) + } else { + Ok(Self { file }) + } + } + + pub unsafe fn poll( + &self, + info: &mut AfdPollInfo, + iosb: *mut IO_STATUS_BLOCK, + overlapped: *mut c_void, + ) -> std::io::Result { + const IOCTL_AFD_POLL: u32 = 0x00012024; + let info_ptr = info as *mut _ as *mut c_void; + (*iosb).Anonymous.Status = STATUS_PENDING; + + let result = NtDeviceIoControlFile( + self.file.as_raw_handle() as HANDLE, + 0, + None, + overlapped, + iosb, + IOCTL_AFD_POLL, + info_ptr, + std::mem::size_of::() as u32, + info_ptr, + std::mem::size_of::() as u32, + ); + + match result { + STATUS_SUCCESS => Ok(true), + STATUS_PENDING => Ok(false), + status => { + let error = RtlNtStatusToDosError(status); + Err(std::io::Error::from_raw_os_error(error as i32)) + } + } + } + + pub unsafe fn cancel(&self, iosb: *mut IO_STATUS_BLOCK) -> std::io::Result<()> { + if (*iosb).Anonymous.Status != STATUS_PENDING { + return Ok(()); + } + let mut cancel_iosb = IO_STATUS_BLOCK { + Anonymous: IO_STATUS_BLOCK_0 { Status: 0 }, + Information: 0, + }; + let status = NtCancelIoFileEx(self.file.as_raw_handle() as HANDLE, iosb, &mut cancel_iosb); + + if status == STATUS_SUCCESS || status == STATUS_NOT_FOUND { + Ok(()) + } else { + let error = RtlNtStatusToDosError(status); + Err(std::io::Error::from_raw_os_error(error as i32)) + } + } +} diff --git a/monoio/src/driver/legacy/iocp/event.rs b/monoio/src/driver/legacy/iocp/event.rs new file mode 100644 index 00000000..0f962ff8 --- /dev/null +++ b/monoio/src/driver/legacy/iocp/event.rs @@ -0,0 +1,120 @@ +use mio::Token; +use 
windows_sys::Win32::System::IO::OVERLAPPED_ENTRY; + +use super::afd; + +#[derive(Clone)] +pub struct Event { + pub flags: u32, + pub data: u64, +} + +impl Event { + pub fn new(token: Token) -> Event { + Event { + flags: 0, + data: usize::from(token) as u64, + } + } + + pub fn token(&self) -> Token { + Token(self.data as usize) + } + + pub fn set_readable(&mut self) { + self.flags |= afd::POLL_RECEIVE + } + + pub fn set_writable(&mut self) { + self.flags |= afd::POLL_SEND; + } + + pub fn from_entry(status: &OVERLAPPED_ENTRY) -> Event { + Event { + flags: status.dwNumberOfBytesTransferred, + data: status.lpCompletionKey as u64, + } + } + + pub fn to_entry(&self) -> OVERLAPPED_ENTRY { + OVERLAPPED_ENTRY { + dwNumberOfBytesTransferred: self.flags, + lpCompletionKey: self.data as usize, + lpOverlapped: std::ptr::null_mut(), + Internal: 0, + } + } + + pub fn is_readable(&self) -> bool { + self.flags & READABLE_FLAGS != 0 + } + + pub fn is_writable(&self) -> bool { + self.flags & WRITABLE_FLAGS != 0 + } + + pub fn is_error(&self) -> bool { + self.flags & ERROR_FLAGS != 0 + } + + pub fn is_read_closed(&self) -> bool { + self.flags & READ_CLOSED_FLAGS != 0 + } + + pub fn is_write_closed(&self) -> bool { + self.flags & WRITE_CLOSED_FLAGS != 0 + } + + pub fn is_priority(&self) -> bool { + self.flags & afd::POLL_RECEIVE_EXPEDITED != 0 + } +} + +pub(crate) const READABLE_FLAGS: u32 = afd::POLL_RECEIVE + | afd::POLL_DISCONNECT + | afd::POLL_ACCEPT + | afd::POLL_ABORT + | afd::POLL_CONNECT_FAIL; +pub(crate) const WRITABLE_FLAGS: u32 = afd::POLL_SEND | afd::POLL_ABORT | afd::POLL_CONNECT_FAIL; +pub(crate) const ERROR_FLAGS: u32 = afd::POLL_CONNECT_FAIL; +pub(crate) const READ_CLOSED_FLAGS: u32 = + afd::POLL_DISCONNECT | afd::POLL_ABORT | afd::POLL_CONNECT_FAIL; +pub(crate) const WRITE_CLOSED_FLAGS: u32 = afd::POLL_ABORT | afd::POLL_CONNECT_FAIL; + +pub struct Events { + pub statuses: Box<[OVERLAPPED_ENTRY]>, + + pub events: Vec, +} + +impl Events { + pub fn with_capacity(cap: 
usize) -> Events {
+        Events {
+            statuses: unsafe { vec![std::mem::zeroed(); cap].into_boxed_slice() },
+            events: Vec::with_capacity(cap),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.events.is_empty()
+    }
+
+    pub fn capacity(&self) -> usize {
+        self.events.capacity()
+    }
+
+    pub fn len(&self) -> usize {
+        self.events.len()
+    }
+
+    pub fn get(&self, idx: usize) -> Option<&Event> {
+        self.events.get(idx)
+    }
+
+    pub fn clear(&mut self) {
+        self.events.clear();
+        for status in self.statuses.iter_mut() {
+            *status = unsafe { std::mem::zeroed() };
+        }
+    }
+}
diff --git a/monoio/src/driver/legacy/iocp/iocp.rs b/monoio/src/driver/legacy/iocp/iocp.rs
new file mode 100644
index 00000000..4b1a5c5f
--- /dev/null
+++ b/monoio/src/driver/legacy/iocp/iocp.rs
@@ -0,0 +1,132 @@
+use std::{
+    os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle},
+    time::Duration,
+};
+
+use windows_sys::Win32::{
+    Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE},
+    System::IO::{
+        CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus,
+        OVERLAPPED_ENTRY,
+    },
+};
+
+/// Owned handle to an I/O completion port; the handle is closed on drop.
+#[derive(Debug)]
+pub struct CompletionPort {
+    handle: HANDLE,
+}
+
+impl CompletionPort {
+    /// Creates a completion port not yet associated with any handle; `value` is passed on
+    /// as the `NumberOfConcurrentThreads` argument of `CreateIoCompletionPort`.
+    pub fn new(value: u32) -> std::io::Result<Self> {
+        let handle = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, value) };
+
+        if handle == 0 {
+            Err(std::io::Error::last_os_error())
+        } else {
+            Ok(Self { handle })
+        }
+    }
+
+    /// Associates `handle` with this port, using `token` as its completion key.
+    pub fn add_handle(&self, token: usize, handle: HANDLE) -> std::io::Result<()> {
+        let result = unsafe { CreateIoCompletionPort(handle, self.handle, token, 0) };
+
+        if result == 0 {
+            Err(std::io::Error::last_os_error())
+        } else {
+            Ok(())
+        }
+    }
+
+    /// Waits up to `timeout` (forever if `None`) for completion packets and returns the
+    /// filled prefix of `entries`.
+    pub fn get_many<'a>(
+        &self,
+        entries: &'a mut [OVERLAPPED_ENTRY],
+        timeout: Option<Duration>,
+    ) -> std::io::Result<&'a mut [OVERLAPPED_ENTRY]> {
+        let mut count = 0;
+        let result = unsafe {
+            GetQueuedCompletionStatusEx(
+                self.handle,
+                entries.as_mut_ptr(),
+                std::cmp::min(entries.len(), u32::MAX as usize) as u32,
+                &mut count,
+                duration_millis(timeout),
+                0,
+            )
+        };
+
+        if result == 0 {
+            Err(std::io::Error::last_os_error())
+        } else {
+            Ok(&mut entries[..count as usize])
+        }
+    }
+
+    /// Queues a packet built from `entry`'s byte count, completion key and overlapped pointer.
+    pub fn post(&self, entry: OVERLAPPED_ENTRY) -> std::io::Result<()> {
+        let result = unsafe {
+            PostQueuedCompletionStatus(
+                self.handle,
+                entry.dwNumberOfBytesTransferred,
+                entry.lpCompletionKey,
+                entry.lpOverlapped,
+            )
+        };
+
+        if result == 0 {
+            Err(std::io::Error::last_os_error())
+        } else {
+            Ok(())
+        }
+    }
+}
+
+impl Drop for CompletionPort {
+    fn drop(&mut self) {
+        unsafe { CloseHandle(self.handle) };
+    }
+}
+
+impl AsRawHandle for CompletionPort {
+    fn as_raw_handle(&self) -> RawHandle {
+        self.handle as RawHandle
+    }
+}
+
+impl FromRawHandle for CompletionPort {
+    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
+        Self {
+            handle: handle as HANDLE,
+        }
+    }
+}
+
+impl IntoRawHandle for CompletionPort {
+    fn into_raw_handle(self) -> RawHandle {
+        // Ownership moves to the caller, so `Drop` must not run and close the handle.
+        std::mem::ManuallyDrop::new(self).handle as RawHandle
+    }
+}
+
+/// Converts an optional timeout to milliseconds; `None` becomes `u32::MAX`.
+#[inline]
+fn duration_millis(dur: Option<Duration>) -> u32 {
+    if let Some(dur) = dur {
+        // `Duration::as_millis` truncates, so round up. This avoids
+        // turning sub-millisecond timeouts into a zero timeout, unless
+        // the caller explicitly requests that by specifying a zero
+        // timeout.
+        let dur_ms = dur
+            .checked_add(Duration::from_nanos(999_999))
+            .unwrap_or(dur)
+            .as_millis();
+        std::cmp::min(dur_ms, u32::MAX as u128) as u32
+    } else {
+        u32::MAX
+    }
+}
diff --git a/monoio/src/driver/legacy/iocp/mod.rs b/monoio/src/driver/legacy/iocp/mod.rs
new file mode 100644
index 00000000..bfdcec64
--- /dev/null
+++ b/monoio/src/driver/legacy/iocp/mod.rs
@@ -0,0 +1,324 @@
+mod afd;
+mod event;
+mod iocp;
+mod state;
+mod waker;
+
+use std::{
+    collections::VecDeque,
+    os::windows::prelude::RawSocket,
+    pin::Pin,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc, Mutex,
+    },
+    time::Duration,
+};
+
+pub use afd::*;
+pub use event::*;
+pub use iocp::*;
+pub use state::*;
+pub use waker::*;
+use windows_sys::Win32::{
+    Foundation::WAIT_TIMEOUT,
+    System::IO::{OVERLAPPED, OVERLAPPED_ENTRY},
+};
+
+/// IOCP-backed readiness poller for sockets, driven through AFD poll operations.
+pub struct Poller {
+    /// Set while a thread is inside `poll_inner`, so that `register`/`reregister`
+    /// know to submit pending updates right away.
+    is_polling: AtomicBool,
+    /// Reference-counted so that a `Waker` can keep its own clone of the port.
+    cp: Arc<CompletionPort>,
+    /// Sockets whose AFD poll request must be (re)submitted before the next wait.
+    update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
+    /// AFD device handles, each shared by a group of sockets.
+    afd: Mutex<Vec<Arc<Afd>>>,
+}
+
+impl Poller {
+    /// Creates a poller backed by a fresh I/O completion port.
+    pub fn new() -> std::io::Result<Self> {
+        Ok(Self {
+            is_polling: AtomicBool::new(false),
+            cp: Arc::new(CompletionPort::new(0)?),
+            update_queue: Mutex::new(VecDeque::new()),
+            afd: Mutex::new(Vec::new()),
+        })
+    }
+
+    /// Waits for events and stores them in `events`, replacing its previous contents.
+    ///
+    /// With no timeout this keeps waiting until at least one event is produced.
+    pub fn poll(&self, events: &mut Events, timeout: Option<Duration>) -> std::io::Result<()> {
+        events.clear();
+
+        if timeout.is_none() {
+            loop {
+                let len = self.poll_inner(&mut events.statuses, &mut events.events, None)?;
+                if len == 0 {
+                    continue;
+                }
+                break Ok(());
+            }
+        } else {
+            self.poll_inner(&mut events.statuses, &mut events.events, timeout)?;
+            Ok(())
+        }
+    }
+
+    /// Submits queued socket updates, waits on the completion port once and appends the
+    /// translated events to `events`; returns how many were appended (0 on timeout).
+    pub fn poll_inner(
+        &self,
+        entries: &mut [OVERLAPPED_ENTRY],
+        events: &mut Vec<Event>,
+        timeout: Option<Duration>,
+    ) -> std::io::Result<usize> {
+        self.is_polling.swap(true, Ordering::AcqRel);
+
+        unsafe { self.update_sockets_events() }?;
+
+        let result = self.cp.get_many(entries, timeout);
+
+        self.is_polling.store(false, Ordering::Relaxed);
+
+        match result {
+            Ok(iocp_events) => Ok(unsafe { self.feed_events(events, iocp_events) }),
+            Err(ref e) if e.raw_os_error() == Some(WAIT_TIMEOUT
as i32) => Ok(0), + Err(e) => Err(e), + } + } + + unsafe fn update_sockets_events(&self) -> std::io::Result<()> { + let mut queue = self.update_queue.lock().unwrap(); + for sock in queue.iter_mut() { + let mut sock_internal = sock.lock().unwrap(); + if !sock_internal.delete_pending { + sock_internal.update(sock)?; + } + } + + queue.retain(|sock| sock.lock().unwrap().error.is_some()); + + let mut afd = self.afd.lock().unwrap(); + afd.retain(|g| Arc::strong_count(g) > 1); + Ok(()) + } + + unsafe fn feed_events(&self, events: &mut Vec, entries: &[OVERLAPPED_ENTRY]) -> usize { + let mut n = 0; + let mut update_queue = self.update_queue.lock().unwrap(); + for entry in entries.iter() { + if entry.lpOverlapped.is_null() { + events.push(Event::from_entry(entry)); + n += 1; + continue; + } + + let sock_state = from_overlapped(entry.lpOverlapped); + let mut sock_guard = sock_state.lock().unwrap(); + if let Some(e) = sock_guard.feed_event() { + events.push(e); + n += 1; + } + + if !sock_guard.delete_pending { + update_queue.push_back(sock_state.clone()); + } + } + let mut afd = self.afd.lock().unwrap(); + afd.retain(|sock| Arc::strong_count(sock) > 1); + n + } + + pub fn register( + &self, + state: &mut SocketState, + token: mio::Token, + interests: mio::Interest, + ) -> std::io::Result<()> { + if state.inner.is_none() { + let flags = interests_to_afd_flags(interests); + + let inner = { + let sock = self._alloc_sock_for_rawsocket(state.socket)?; + let event = Event { + flags, + data: token.0 as u64, + }; + sock.lock().unwrap().set_event(event); + sock + }; + + self.queue_state(inner.clone()); + unsafe { self.update_sockets_events_if_polling()? 
}; + state.inner = Some(inner); + state.token = token; + state.interest = interests; + + Ok(()) + } else { + Err(std::io::ErrorKind::AlreadyExists.into()) + } + } + + pub fn reregister( + &self, + state: &mut SocketState, + token: mio::Token, + interests: mio::Interest, + ) -> std::io::Result<()> { + if let Some(inner) = state.inner.as_mut() { + { + let event = Event { + flags: interests_to_afd_flags(interests), + data: token.0 as u64, + }; + + inner.lock().unwrap().set_event(event); + } + + state.token = token; + state.interest = interests; + + self.queue_state(inner.clone()); + unsafe { self.update_sockets_events_if_polling() } + } else { + Err(std::io::ErrorKind::NotFound.into()) + } + } + + pub fn deregister(&mut self, state: &mut SocketState) -> std::io::Result<()> { + if let Some(inner) = state.inner.as_mut() { + { + let mut sock_state = inner.lock().unwrap(); + sock_state.mark_delete(); + } + state.inner = None; + Ok(()) + } else { + Err(std::io::ErrorKind::NotFound.into()) + } + } + + /// This function is called by register() and reregister() to start an + /// IOCTL_AFD_POLL operation corresponding to the registered events, but + /// only if necessary. + /// + /// Since it is not possible to modify or synchronously cancel an AFD_POLL + /// operation, and there can be only one active AFD_POLL operation per + /// (socket, completion port) pair at any time, it is expensive to change + /// a socket's event registration after it has been submitted to the kernel. + /// + /// Therefore, if no other threads are polling when interest in a socket + /// event is (re)registered, the socket is added to the 'update queue', but + /// the actual syscall to start the IOCTL_AFD_POLL operation is deferred + /// until just before the GetQueuedCompletionStatusEx() syscall is made. + /// + /// However, when another thread is already blocked on + /// GetQueuedCompletionStatusEx() we tell the kernel about the registered + /// socket event(s) immediately. 
+ unsafe fn update_sockets_events_if_polling(&self) -> std::io::Result<()> { + if self.is_polling.load(Ordering::Acquire) { + self.update_sockets_events() + } else { + Ok(()) + } + } + + fn queue_state(&self, sock_state: Pin>>) { + let mut update_queue = self.update_queue.lock().unwrap(); + update_queue.push_back(sock_state); + } + + fn _alloc_sock_for_rawsocket( + &self, + raw_socket: RawSocket, + ) -> std::io::Result>>> { + const POLL_GROUP__MAX_GROUP_SIZE: usize = 32; + + let mut afd_group = self.afd.lock().unwrap(); + if afd_group.len() == 0 { + self._alloc_afd_group(&mut afd_group)?; + } else { + // + 1 reference in Vec + if Arc::strong_count(afd_group.last().unwrap()) > POLL_GROUP__MAX_GROUP_SIZE { + self._alloc_afd_group(&mut afd_group)?; + } + } + let afd = match afd_group.last() { + Some(arc) => arc.clone(), + None => unreachable!("Cannot acquire afd"), + }; + + Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?))) + } + + fn _alloc_afd_group(&self, afd_group: &mut Vec>) -> std::io::Result<()> { + let afd = Afd::new(&self.cp)?; + let arc = Arc::new(afd); + afd_group.push(arc); + Ok(()) + } +} + +impl Drop for Poller { + fn drop(&mut self) { + loop { + let count: usize; + let mut statuses: [OVERLAPPED_ENTRY; 1024] = unsafe { std::mem::zeroed() }; + + let result = self + .cp + .get_many(&mut statuses, Some(std::time::Duration::from_millis(0))); + match result { + Ok(events) => { + count = events.iter().len(); + for event in events.iter() { + if event.lpOverlapped.is_null() { + } else { + // drain sock state to release memory of Arc reference + let _ = from_overlapped(event.lpOverlapped); + } + } + } + Err(_) => break, + } + + if count == 0 { + break; + } + } + + let mut afd_group = self.afd.lock().unwrap(); + afd_group.retain(|g| Arc::strong_count(g) > 1); + } +} + +pub fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin>> { + let sock_ptr: *const Mutex = ptr as *const _; + unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) } +} + +pub fn 
into_overlapped(sock_state: Pin>>) -> *mut std::ffi::c_void { + let overlapped_ptr: *const Mutex = + unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) }; + overlapped_ptr as *mut _ +} + +pub fn interests_to_afd_flags(interests: mio::Interest) -> u32 { + let mut flags = 0; + + if interests.is_readable() { + flags |= READABLE_FLAGS | READ_CLOSED_FLAGS | ERROR_FLAGS; + } + + if interests.is_writable() { + flags |= WRITABLE_FLAGS | WRITE_CLOSED_FLAGS | ERROR_FLAGS; + } + + flags +} diff --git a/monoio/src/driver/legacy/iocp/state.rs b/monoio/src/driver/legacy/iocp/state.rs new file mode 100644 index 00000000..8fb51010 --- /dev/null +++ b/monoio/src/driver/legacy/iocp/state.rs @@ -0,0 +1,282 @@ +use std::{ + marker::PhantomPinned, + os::windows::prelude::RawSocket, + pin::Pin, + sync::{Arc, Mutex}, +}; + +use windows_sys::Win32::{ + Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, STATUS_CANCELLED}, + Networking::WinSock::{ + WSAGetLastError, WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL, + SIO_BSP_HANDLE_SELECT, SOCKET_ERROR, + }, + System::WindowsProgramming::IO_STATUS_BLOCK, +}; + +use super::{afd, from_overlapped, into_overlapped, Afd, AfdPollInfo, Event}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SockPollStatus { + Idle, + Pending, + Cancelled, +} + +pub struct SocketState { + pub socket: RawSocket, + pub inner: Option>>>, + pub token: mio::Token, + pub interest: mio::Interest, +} + +impl SocketState { + pub fn new(socket: RawSocket) -> Self { + Self { + socket, + inner: None, + token: mio::Token(0), + interest: mio::Interest::READABLE, + } + } +} + +pub struct SockState { + pub iosb: IO_STATUS_BLOCK, + pub poll_info: AfdPollInfo, + pub afd: Arc, + + pub base_socket: RawSocket, + + pub user_evts: u32, + pub pending_evts: u32, + + pub user_data: u64, + + pub poll_status: SockPollStatus, + pub delete_pending: bool, + + pub error: Option, + + _pinned: PhantomPinned, +} + +impl SockState { + pub fn 
new(raw_socket: RawSocket, afd: Arc) -> std::io::Result { + Ok(SockState { + iosb: unsafe { std::mem::zeroed() }, + poll_info: unsafe { std::mem::zeroed() }, + afd, + base_socket: get_base_socket(raw_socket)?, + user_evts: 0, + pending_evts: 0, + user_data: 0, + poll_status: SockPollStatus::Idle, + delete_pending: false, + error: None, + _pinned: PhantomPinned, + }) + } + + pub fn update(&mut self, self_arc: &Pin>>) -> std::io::Result<()> { + assert!(!self.delete_pending); + + // make sure to reset previous error before a new update + self.error = None; + + if let SockPollStatus::Pending = self.poll_status { + if (self.user_evts & afd::KNOWN_EVENTS & !self.pending_evts) == 0 { + // All the events the user is interested in are already being monitored by + // the pending poll operation. It might spuriously complete because of an + // event that we're no longer interested in; when that happens we'll submit + // a new poll operation with the updated event mask. + } else { + // A poll operation is already pending, but it's not monitoring for all the + // events that the user is interested in. Therefore, cancel the pending + // poll operation; when we receive it's completion package, a new poll + // operation will be submitted with the correct event mask. + if let Err(e) = self.cancel() { + self.error = e.raw_os_error(); + return Err(e); + } + return Ok(()); + } + } else if let SockPollStatus::Cancelled = self.poll_status { + // The poll operation has already been cancelled, we're still waiting for + // it to return. For now, there's nothing that needs to be done. + } else if let SockPollStatus::Idle = self.poll_status { + // No poll operation is pending; start one. 
+ self.poll_info.exclusive = 0; + self.poll_info.number_of_handles = 1; + self.poll_info.timeout = i64::MAX; + self.poll_info.handles[0].handle = self.base_socket as HANDLE; + self.poll_info.handles[0].status = 0; + self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE; + + // Increase the ref count as the memory will be used by the kernel. + let overlapped_ptr = into_overlapped(self_arc.clone()); + + let result = unsafe { + self.afd + .poll(&mut self.poll_info, &mut self.iosb, overlapped_ptr) + }; + if let Err(e) = result { + let code = e.raw_os_error().unwrap(); + if code == ERROR_IO_PENDING as i32 { + // Overlapped poll operation in progress; this is expected. + } else { + // Since the operation failed it means the kernel won't be + // using the memory any more. + drop(from_overlapped(overlapped_ptr as *mut _)); + if code == ERROR_INVALID_HANDLE as i32 { + // Socket closed; it'll be dropped. + self.mark_delete(); + return Ok(()); + } else { + self.error = e.raw_os_error(); + return Err(e); + } + } + } + + self.poll_status = SockPollStatus::Pending; + self.pending_evts = self.user_evts; + } else { + unreachable!("Invalid poll status during update") + } + + Ok(()) + } + + pub fn feed_event(&mut self) -> Option { + self.poll_status = SockPollStatus::Idle; + self.pending_evts = 0; + + let mut afd_events = 0; + // We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is + // unsafe to use a pointer of IO_STATUS_BLOCK. + unsafe { + if self.delete_pending { + return None; + } else if self.iosb.Anonymous.Status == STATUS_CANCELLED { + // The poll request was cancelled by CancelIoEx. + } else if self.iosb.Anonymous.Status < 0 { + // The overlapped request itself failed in an unexpected way. + afd_events = afd::POLL_CONNECT_FAIL; + } else if self.poll_info.number_of_handles < 1 { + // This poll operation succeeded but didn't report any socket events. 
+ } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 { + // The poll operation reported that the socket was closed. + self.mark_delete(); + return None; + } else { + afd_events = self.poll_info.handles[0].events; + } + } + + afd_events &= self.user_evts; + + if afd_events == 0 { + return None; + } + + self.user_evts &= !afd_events; + + Some(Event { + data: self.user_data, + flags: afd_events, + }) + } + + pub fn mark_delete(&mut self) { + if !self.delete_pending { + if let SockPollStatus::Pending = self.poll_status { + drop(self.cancel()); + } + + self.delete_pending = true; + } + } + + pub fn set_event(&mut self, ev: Event) -> bool { + // afd::POLL_CONNECT_FAIL and afd::POLL_ABORT are always reported, even when not requested + // by the caller. + let events = ev.flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT; + + self.user_evts = events; + self.user_data = ev.data; + + (events & !self.pending_evts) != 0 + } + + pub fn cancel(&mut self) -> std::io::Result<()> { + match self.poll_status { + SockPollStatus::Pending => {} + _ => unreachable!("Invalid poll status during cancel"), + }; + unsafe { + self.afd.cancel(&mut self.iosb)?; + } + self.poll_status = SockPollStatus::Cancelled; + self.pending_evts = 0; + Ok(()) + } +} + +impl Drop for SockState { + fn drop(&mut self) { + self.mark_delete(); + } +} + +fn get_base_socket(raw_socket: RawSocket) -> std::io::Result { + let res = try_get_base_socket(raw_socket, SIO_BASE_HANDLE); + if let Ok(base_socket) = res { + return Ok(base_socket); + } + + // The `SIO_BASE_HANDLE` should not be intercepted by LSPs, therefore + // it should not fail as long as `raw_socket` is a valid socket. See + // https://docs.microsoft.com/en-us/windows/win32/winsock/winsock-ioctls. + // However, at least one known LSP deliberately breaks it, so we try + // some alternative IOCTLs, starting with the most appropriate one. 
+    for &ioctl in &[SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE] {
+        if let Ok(base_socket) = try_get_base_socket(raw_socket, ioctl) {
+            // Since we know now that we're dealing with an LSP (otherwise
+            // SIO_BASE_HANDLE wouldn't have failed), only return any result
+            // when it is different from the original `raw_socket`.
+            if base_socket != raw_socket {
+                return Ok(base_socket);
+            }
+        }
+    }
+
+    // If the alternative IOCTLs also failed, return the original error.
+    let os_error = res.unwrap_err();
+    let err = std::io::Error::from_raw_os_error(os_error);
+    Err(err)
+}
+
+fn try_get_base_socket(raw_socket: RawSocket, ioctl: u32) -> Result<RawSocket, i32> {
+    let mut base_socket: RawSocket = 0;
+    let mut bytes: u32 = 0;
+    let result = unsafe {
+        WSAIoctl(
+            raw_socket as usize,
+            ioctl,
+            std::ptr::null_mut(),
+            0,
+            &mut base_socket as *mut _ as *mut std::ffi::c_void,
+            std::mem::size_of::<RawSocket>() as u32,
+            &mut bytes,
+            std::ptr::null_mut(),
+            None,
+        )
+    };
+
+    if result != SOCKET_ERROR {
+        Ok(base_socket)
+    } else {
+        Err(unsafe { WSAGetLastError() })
+    }
+}
diff --git a/monoio/src/driver/legacy/iocp/waker.rs b/monoio/src/driver/legacy/iocp/waker.rs
new file mode 100644
index 00000000..b5d9b6b5
--- /dev/null
+++ b/monoio/src/driver/legacy/iocp/waker.rs
@@ -0,0 +1,30 @@
+use std::{io, sync::Arc};
+
+use super::{CompletionPort, Event, Poller};
+
+/// Wakes a thread blocked in the poller by posting a packet to its completion port.
+#[derive(Debug)]
+pub struct Waker {
+    token: mio::Token,
+    port: Arc<CompletionPort>,
+}
+
+impl Waker {
+    /// Creates a waker that reports wake-ups under `token` on `poller`'s completion port.
+    pub fn new(poller: &Poller, token: mio::Token) -> io::Result<Waker> {
+        Ok(Waker {
+            token,
+            // NOTE(review): only type-checks if `Poller::cp` is an `Arc<CompletionPort>`.
+            port: poller.cp.clone(),
+        })
+    }
+
+    /// Posts a readable event carrying this waker's token to the completion port.
+    pub fn wake(&self) -> io::Result<()> {
+        let mut ev = Event::new(self.token);
+        ev.set_readable();
+
+        // `Event::to_entry` builds the `OVERLAPPED_ENTRY` that `CompletionPort::post` expects.
+        self.port.post(ev.to_entry())
+    }
+}
diff --git a/monoio/src/driver/legacy/mod.rs b/monoio/src/driver/legacy/mod.rs
index ffcf0f50..a92a7024 100644
--- a/monoio/src/driver/legacy/mod.rs
+++ b/monoio/src/driver/legacy/mod.rs
@@ -15,6 +15,8 @@ use super::{
 };
 use crate::utils::slab::Slab;
 
+#[cfg(windows)]
+pub(super) mod iocp;
 pub(crate) mod ready;
 mod scheduled_io;
 
@@ -25,8 +27,14 @@ pub(crate) use waker::UnparkHandle;
 
 pub(crate) struct LegacyInner {
     pub(crate) io_dispatch: Slab<ScheduledIo>,
+    #[cfg(unix)]
     events: Option<mio::Events>,
+    #[cfg(unix)]
     poll: mio::Poll,
+    #[cfg(windows)]
+    events: Option<iocp::Events>,
+    #[cfg(windows)]
+    poll: iocp::Poller,
 
     #[cfg(feature = "sync")]
     shared_waker: std::sync::Arc<waker::EventWaker>,
@@ -56,13 +64,21 @@ impl LegacyDriver {
     }
 
     pub(crate) fn new_with_entries(entries: u32) -> io::Result<Self> {
+        #[cfg(unix)]
         let poll = mio::Poll::new()?;
+        #[cfg(windows)]
+        let poll = iocp::Poller::new()?;
 
-        #[cfg(feature = "sync")]
+        #[cfg(all(unix, feature = "sync"))]
         let shared_waker = std::sync::Arc::new(waker::EventWaker::new(mio::Waker::new(
             poll.registry(),
             TOKEN_WAKEUP,
         )?));
+        #[cfg(all(windows, feature = "sync"))]
+        let shared_waker = std::sync::Arc::new(waker::EventWaker::new(iocp::Waker::new(
+            &poll,
+            TOKEN_WAKEUP,
+        )?));
         #[cfg(feature = "sync")]
         let (waker_sender, waker_receiver) = flume::unbounded::<std::task::Waker>();
         #[cfg(feature = "sync")]
@@ -70,7 +86,13 @@ impl LegacyDriver {
 
         let inner = LegacyInner {
             io_dispatch: Slab::new(),
+            #[cfg(unix)]
             events: Some(mio::Events::with_capacity(entries as usize)),
+            #[cfg(unix)]
+            poll,
+            #[cfg(windows)]
+            events: Some(iocp::Events::with_capacity(entries as usize)),
+            #[cfg(windows)]
             poll,
             #[cfg(feature = "sync")]
             shared_waker,
@@ -147,6 +169,47 @@ impl LegacyDriver {
         Ok(())
     }
 
+    /// Allocates a dispatch slot and registers `state` with the IOCP poller under that
+    /// slot's index as token; the slot is released again if registration fails.
+    #[cfg(windows)]
+    pub(crate) fn register(
+        this: &Rc<UnsafeCell<LegacyInner>>,
+        state: &mut iocp::SocketState,
+        interest: mio::Interest,
+    ) -> io::Result<usize> {
+        let inner = unsafe { &mut *this.get() };
+        let io = ScheduledIo::default();
+        let token = inner.io_dispatch.insert(io);
+
+        match inner.poll.register(state, mio::Token(token), interest) {
+            Ok(_) => Ok(token),
+            Err(e) => {
+                inner.io_dispatch.remove(token);
+                Err(e)
+            }
+        }
+    }
+
+    /// Deregisters `state` from the IOCP poller and, on success, frees dispatch slot `token`.
+    #[cfg(windows)]
+    pub(crate) fn deregister(
+        this: &Rc<UnsafeCell<LegacyInner>>,
+        token: usize,
+        state: &mut iocp::SocketState,
+    ) -> io::Result<()> {
+        let inner = unsafe { &mut *this.get() };
+
+        // try to deregister fd first, on success we will remove it from slab.
+        match inner.poll.deregister(state) {
+            Ok(_) => {
+                inner.io_dispatch.remove(token);
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    #[cfg(unix)]
     pub(crate) fn register(
         this: &Rc<UnsafeCell<LegacyInner>>,
         source: &mut impl mio::event::Source,
@@ -166,6 +229,7 @@ impl LegacyDriver {
         }
     }
 
+    #[cfg(unix)]
     pub(crate) fn deregister(
         this: &Rc<UnsafeCell<LegacyInner>>,
         token: usize,
diff --git a/monoio/src/driver/legacy/ready.rs b/monoio/src/driver/legacy/ready.rs
index 11d57d73..374ddcd5 100644
--- a/monoio/src/driver/legacy/ready.rs
+++ b/monoio/src/driver/legacy/ready.rs
@@ -45,6 +45,30 @@ impl Ready {
     pub(crate) const READ_ALL: Ready = Ready(READABLE | READ_CLOSED | READ_CANCELED);
     pub(crate) const WRITE_ALL: Ready = Ready(WRITABLE | WRITE_CLOSED | WRITE_CANCELED);
 
+    #[cfg(windows)]
+    pub(crate) fn from_mio(event: &super::iocp::Event) -> Ready {
+        let mut ready = Ready::EMPTY;
+
+        if event.is_readable() {
+            ready |= Ready::READABLE;
+        }
+
+        if event.is_writable() {
+            ready |= Ready::WRITABLE;
+        }
+
+        if event.is_read_closed() {
+            ready |= Ready::READ_CLOSED;
+        }
+
+        if event.is_write_closed() {
+            ready |= Ready::WRITE_CLOSED;
+        }
+
+        ready
+    }
+
+    #[cfg(unix)]
     // Must remain crate-private to avoid adding a public dependency on Mio.
pub(crate) fn from_mio(event: &mio::event::Event) -> Ready { let mut ready = Ready::EMPTY; diff --git a/monoio/src/driver/legacy/waker.rs b/monoio/src/driver/legacy/waker.rs index 7d8e7b7e..40290e96 100644 --- a/monoio/src/driver/legacy/waker.rs +++ b/monoio/src/driver/legacy/waker.rs @@ -2,12 +2,16 @@ use crate::driver::unpark::Unpark; pub(crate) struct EventWaker { // raw waker + #[cfg(windows)] + waker: super::iocp::Waker, + #[cfg(unix)] waker: mio::Waker, // Atomic awake status pub(crate) awake: std::sync::atomic::AtomicBool, } impl EventWaker { + #[cfg(unix)] pub(crate) fn new(waker: mio::Waker) -> Self { Self { waker, @@ -15,6 +19,14 @@ impl EventWaker { } } + #[cfg(windows)] + pub(crate) fn new(waker: super::iocp::Waker) -> Self { + Self { + waker, + awake: std::sync::atomic::AtomicBool::new(true), + } + } + pub(crate) fn wake(&self) -> std::io::Result<()> { // Skip wake if already awake if self.awake.load(std::sync::atomic::Ordering::Acquire) { diff --git a/monoio/src/driver/mod.rs b/monoio/src/driver/mod.rs index d5094506..cf443303 100644 --- a/monoio/src/driver/mod.rs +++ b/monoio/src/driver/mod.rs @@ -5,7 +5,7 @@ pub(crate) mod shared_fd; #[cfg(feature = "sync")] pub(crate) mod thread; -#[cfg(all(unix, feature = "legacy"))] +#[cfg(feature = "legacy")] mod legacy; #[cfg(all(target_os = "linux", feature = "iouring"))] mod uring; @@ -18,7 +18,7 @@ use std::{ time::Duration, }; -#[cfg(all(unix, feature = "legacy"))] +#[cfg(feature = "legacy")] pub use self::legacy::LegacyDriver; // #[cfg(windows)] // pub mod op { @@ -28,7 +28,7 @@ pub use self::legacy::LegacyDriver; // } // pub trait OpAble {} // } -#[cfg(all(unix, feature = "legacy"))] +#[cfg(feature = "legacy")] use self::legacy::LegacyInner; use self::op::{CompletionMeta, Op, OpAble}; #[cfg(all(target_os = "linux", feature = "iouring"))] @@ -91,7 +91,7 @@ scoped_thread_local!(pub(crate) static CURRENT: Inner); pub(crate) enum Inner { #[cfg(all(target_os = "linux", feature = "iouring"))] 
Uring(std::rc::Rc>), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Legacy(std::rc::Rc>), } @@ -102,7 +102,7 @@ impl Inner { _ => unimplemented!(), #[cfg(all(target_os = "linux", feature = "iouring"))] Inner::Uring(this) => UringInner::submit_with_data(this, data), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Inner::Legacy(this) => LegacyInner::submit_with_data(this, data), #[cfg(all( not(feature = "legacy"), @@ -126,7 +126,7 @@ impl Inner { _ => unimplemented!(), #[cfg(all(target_os = "linux", feature = "iouring"))] Inner::Uring(this) => UringInner::poll_op(this, index, cx), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Inner::Legacy(this) => LegacyInner::poll_op::(this, data, cx), #[cfg(all( not(feature = "legacy"), @@ -145,7 +145,7 @@ impl Inner { _ => unimplemented!(), #[cfg(all(target_os = "linux", feature = "iouring"))] Inner::Uring(this) => UringInner::drop_op(this, index, data), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Inner::Legacy(_) => {} #[cfg(all( not(feature = "legacy"), @@ -164,7 +164,7 @@ impl Inner { _ => unimplemented!(), #[cfg(all(target_os = "linux", feature = "iouring"))] Inner::Uring(this) => UringInner::cancel_op(this, op_canceller.index), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Inner::Legacy(this) => { if let Some(direction) = op_canceller.direction { LegacyInner::cancel_op(this, op_canceller.index, direction) @@ -203,7 +203,7 @@ impl Inner { pub(crate) enum UnparkHandle { #[cfg(all(target_os = "linux", feature = "iouring"))] Uring(self::uring::UnparkHandle), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Legacy(self::legacy::UnparkHandle), } @@ -213,7 +213,7 @@ impl unpark::Unpark for UnparkHandle { match self { #[cfg(all(target_os = "linux", feature = "iouring"))] UnparkHandle::Uring(inner) => inner.unpark(), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] 
UnparkHandle::Legacy(inner) => inner.unpark(), #[cfg(all( not(feature = "legacy"), @@ -247,7 +247,7 @@ impl UnparkHandle { CURRENT.with(|inner| match inner { #[cfg(all(target_os = "linux", feature = "iouring"))] Inner::Uring(this) => UringInner::unpark(this).into(), - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] Inner::Legacy(this) => LegacyInner::unpark(this).into(), }) } diff --git a/monoio/src/driver/shared_fd.rs b/monoio/src/driver/shared_fd.rs index aee48b8d..a44247ef 100644 --- a/monoio/src/driver/shared_fd.rs +++ b/monoio/src/driver/shared_fd.rs @@ -1,9 +1,11 @@ #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; #[cfg(windows)] -use std::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; +use std::os::windows::io::{AsRawSocket, FromRawSocket, OwnedSocket, RawSocket}; use std::{cell::UnsafeCell, io, rc::Rc}; +#[cfg(windows)] +use super::legacy::iocp::SocketState; use super::CURRENT; // Tracks in-flight operations on a file descriptor. Ensures all in-flight @@ -19,7 +21,7 @@ struct Inner { fd: RawFd, #[cfg(windows)] - fd: RawHandle, + fd: SocketState, // Waker to notify when the close operation completes. 
state: UnsafeCell, @@ -61,9 +63,9 @@ impl AsRawFd for SharedFd { } #[cfg(windows)] -impl AsRawHandle for SharedFd { - fn as_raw_handle(&self) -> RawHandle { - self.raw_handle() +impl AsRawSocket for SharedFd { + fn as_raw_socket(&self) -> RawSocket { + self.raw_socket() } } @@ -126,8 +128,28 @@ impl SharedFd { } #[cfg(windows)] - pub(crate) fn new(fd: RawHandle) -> io::Result { - unimplemented!() + pub(crate) fn new(fd: RawSocket) -> io::Result { + const RW_INTERESTS: mio::Interest = mio::Interest::READABLE.add(mio::Interest::WRITABLE); + + let mut fd = SocketState::new(fd); + + let state = { + let reg = CURRENT.with(|inner| match inner { + super::Inner::Legacy(inner) => { + super::legacy::LegacyDriver::register(inner, &mut fd, RW_INTERESTS) + } + }); + + State::Legacy(Some(reg?)) + }; + + #[allow(unreachable_code)] + Ok(SharedFd { + inner: Rc::new(Inner { + fd, + state: UnsafeCell::new(state), + }), + }) } #[cfg(unix)] @@ -157,8 +179,17 @@ impl SharedFd { #[cfg(windows)] #[allow(unreachable_code, unused)] - pub(crate) fn new_without_register(fd: RawHandle) -> io::Result { - unimplemented!() + pub(crate) fn new_without_register(fd: RawSocket) -> SharedFd { + let state = CURRENT.with(|inner| match inner { + super::Inner::Legacy(_) => State::Legacy(None), + }); + + SharedFd { + inner: Rc::new(Inner { + fd: SocketState::new(fd), + state: UnsafeCell::new(state), + }), + } } #[cfg(unix)] @@ -168,9 +199,9 @@ impl SharedFd { } #[cfg(windows)] - /// Returns the RawHandle - pub(crate) fn raw_handle(&self) -> RawHandle { - self.inner.fd + /// Returns the RawSocket + pub(crate) fn raw_socket(&self) -> RawSocket { + self.inner.fd.socket } #[cfg(unix)] @@ -217,19 +248,42 @@ impl SharedFd { #[cfg(windows)] /// Try unwrap Rc, then deregister if registered and return rawfd. /// Note: this action will consume self and return rawfd without closing it. 
- pub(crate) fn try_unwrap(self) -> Result { - unimplemented!() + pub(crate) fn try_unwrap(self) -> Result { + let fd = self.inner.fd; + match Rc::try_unwrap(self.inner) { + Ok(_inner) => { + let state = unsafe { &*_inner.state.get() }; + + #[allow(irrefutable_let_patterns)] + if let State::Legacy(idx) = state { + if CURRENT.is_set() { + CURRENT.with(|inner| { + match inner { + super::Inner::Legacy(inner) => { + // deregister it from driver(Poll and slab) and close fd + if let Some(idx) = idx { + let _ = super::legacy::LegacyDriver::deregister( + inner, *idx, &mut fd, + ); + } + } + } + }) + } + } + Ok(fd.socket) + } + Err(inner) => Err(Self { inner }), + } } #[allow(unused)] pub(crate) fn registered_index(&self) -> Option { let state = unsafe { &*self.inner.state.get() }; match state { - #[cfg(windows)] - _ => unimplemented!(), #[cfg(all(target_os = "linux", feature = "iouring"))] State::Uring(_) => None, - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] State::Legacy(s) => *s, #[cfg(all( not(feature = "legacy"), @@ -325,7 +379,7 @@ impl Drop for Inner { let _ = unsafe { std::fs::File::from_raw_fd(fd) }; }; } - #[cfg(all(unix, feature = "legacy"))] + #[cfg(feature = "legacy")] State::Legacy(idx) => { if CURRENT.is_set() { CURRENT.with(|inner| { @@ -346,12 +400,23 @@ impl Drop for Inner { ); } } + #[cfg(windows)] + super::Inner::Legacy(inner) => { + // deregister it from driver(Poll and slab) and close fd + if let Some(idx) = idx { + let _ = super::legacy::LegacyDriver::deregister( + inner, *idx, &mut fd, + ); + } + } } }) } + #[cfg(all(unix, feature = "legacy"))] let _ = unsafe { std::fs::File::from_raw_fd(fd) }; + #[cfg(windows)] + let _ = unsafe { OwnedSocket::from_raw_socket(fd.socket) }; } - // TODO: windows _ => {} } }