diff --git a/src/sys/windows/io_status_block.rs b/src/sys/windows/io_status_block.rs index 9035b7b19..0309cb733 100644 --- a/src/sys/windows/io_status_block.rs +++ b/src/sys/windows/io_status_block.rs @@ -1,27 +1,30 @@ use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK}; -use std::cell::UnsafeCell; use std::fmt; +use std::ops::{Deref, DerefMut}; -pub struct IoStatusBlock(UnsafeCell); - -// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use that. Thus it is safe to implement Send here. -unsafe impl Send for IoStatusBlock {} +pub struct IoStatusBlock(IO_STATUS_BLOCK); impl IoStatusBlock { - pub fn zeroed() -> IoStatusBlock { - let iosb = IO_STATUS_BLOCK { + pub fn zeroed() -> Self { + Self(IO_STATUS_BLOCK { u: IO_STATUS_BLOCK_u { Status: 0 }, Information: 0, - }; - IoStatusBlock(UnsafeCell::new(iosb)) + }) } +} - pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK { - self.0.get() +unsafe impl Send for IoStatusBlock {} + +impl Deref for IoStatusBlock { + type Target = IO_STATUS_BLOCK; + fn deref(&self) -> &Self::Target { + &self.0 } +} - pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK { - self.0.get() +impl DerefMut for IoStatusBlock { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 513d14671..88406e604 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,6 +1,7 @@ use std::io; use std::mem::size_of_val; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::pin::Pin; use std::sync::{Arc, Mutex, Once}; use winapi::ctypes::c_int; use winapi::shared::ws2def::SOCKADDR; @@ -64,8 +65,12 @@ pub use udp::UdpSocket; pub use waker::Waker; pub trait SocketState { - fn get_sock_state(&self) -> Option>>; - fn set_sock_state(&self, sock_state: Option>>); + // The `SockState` struct needs to be pinned in memory because it contains + // `OVERLAPPED` and `AFD_POLL_INFO` fields which are modified in the + // background by the windows kernel, therefore we need to ensure they are + // never moved to a different memory address. + fn get_sock_state(&self) -> Option>>>; + fn set_sock_state(&self, sock_state: Option>>>); } use crate::{Interests, Token}; @@ -74,7 +79,7 @@ struct InternalState { selector: Arc, token: Token, interests: Interests, - sock_state: Option>>, + sock_state: Option>>>, } impl InternalState { diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index 434f43700..551576bc9 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -8,6 +8,7 @@ use crate::{Interests, Token}; use miow::iocp::{CompletionPort, CompletionStatus}; use miow::Overlapped; use std::collections::VecDeque; +use std::marker::PhantomPinned; use std::mem::size_of; use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; @@ -22,6 +23,7 @@ use winapi::shared::ntdef::NT_SUCCESS; use winapi::shared::ntdef::{HANDLE, PVOID}; use winapi::shared::ntstatus::STATUS_CANCELLED; use winapi::shared::winerror::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, WAIT_TIMEOUT}; +use winapi::um::minwinbase::OVERLAPPED; use winapi::um::mswsock::SIO_BASE_HANDLE; use winapi::um::winsock2::{WSAIoctl, INVALID_SOCKET, SOCKET_ERROR}; @@ -75,38 +77,6 @@ impl AfdGroup { } } -/// This is the deallocation wrapper for overlapped pointer. -/// In case of error or status changing before the overlapped pointer is actually used(or not even being used), -/// this wrapper will decrease the reference count of Arc if being dropped. -/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free. -#[derive(Debug)] -struct OverlappedArcWrapper(*const T); - -unsafe impl Send for OverlappedArcWrapper {} - -impl OverlappedArcWrapper { - fn new(arc: &Arc) -> OverlappedArcWrapper { - OverlappedArcWrapper(Arc::into_raw(arc.clone())) - } - - fn forget(&mut self) { - self.0 = 0 as *const T; - } - - fn get_ptr(&self) -> *const T { - self.0 - } -} - -impl Drop for OverlappedArcWrapper { - fn drop(&mut self) { - if self.0 as usize == 0 { - return; - } - drop(unsafe { Arc::from_raw(self.0) }); - } -} - #[derive(Debug)] enum SockPollStatus { Idle, @@ -116,7 +86,7 @@ enum SockPollStatus { #[derive(Debug)] pub struct SockState { - iosb: Pin>, + iosb: IoStatusBlock, poll_info: AfdPollInfo, afd: Arc, @@ -129,15 +99,15 @@ pub struct SockState { user_data: u64, poll_status: SockPollStatus, - self_wrapped: Option>>, - delete_pending: bool, + + pinned: PhantomPinned, } impl SockState { fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { Ok(SockState { - iosb: Pin::new(Box::new(IoStatusBlock::zeroed())), + iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), afd, raw_socket, @@ -146,8 +116,8 @@ impl SockState { pending_evts: 0, user_data: 0, poll_status: SockPollStatus::Idle, - self_wrapped: None, delete_pending: false, + pinned: PhantomPinned, }) } @@ -162,7 +132,7 @@ impl SockState { (events & !self.pending_evts) != 0 } - fn update(&mut self, self_arc: &Arc>) -> io::Result<()> { + fn update(&mut self, self_arc: &Pin>>) -> io::Result<()> { assert!(!self.delete_pending); if let SockPollStatus::Pending = self.poll_status { @@ -185,38 +155,37 @@ impl SockState { /* No poll operation is pending; start one. */ self.poll_info.exclusive = 0; self.poll_info.number_of_handles = 1; - unsafe { - *self.poll_info.timeout.QuadPart_mut() = std::i64::MAX; - } + *unsafe { self.poll_info.timeout.QuadPart_mut() } = std::i64::MAX; self.poll_info.handles[0].handle = self.base_socket as HANDLE; self.poll_info.handles[0].status = 0; self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE; - let wrapped_overlapped = OverlappedArcWrapper::new(self_arc); - let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID; + // Increase the ref count as the memory will be used by the kernel. + let overlapped_ptr = into_overlapped(self_arc.clone()); + let result = unsafe { self.afd - .poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped) + .poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr) }; if let Err(e) = result { let code = e.raw_os_error().unwrap(); if code == ERROR_IO_PENDING as i32 { /* Overlapped poll operation in progress; this is expected. */ - } else if code == ERROR_INVALID_HANDLE as i32 { - /* Socket closed; it'll be dropped. */ - self.mark_delete(); - return Ok(()); } else { - return Err(e); + // Since the operation failed it means the kernel won't be + // using the memory any more. + drop(from_overlapped(overlapped_ptr as *mut _)); + if code == ERROR_INVALID_HANDLE as i32 { + /* Socket closed; it'll be dropped. */ + self.mark_delete(); + return Ok(()); + } else { + return Err(e); + } } } - if self.self_wrapped.is_some() { - // This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable. - unreachable!(); - } self.poll_status = SockPollStatus::Pending; - self.self_wrapped = Some(wrapped_overlapped); self.pending_evts = self.user_evts; } else { unreachable!(); @@ -230,7 +199,7 @@ impl SockState { _ => unreachable!(), }; unsafe { - self.afd.cancel((*self.iosb).as_mut_ptr())?; + self.afd.cancel(&mut *self.iosb)?; } self.poll_status = SockPollStatus::Cancelled; self.pending_evts = 0; @@ -239,24 +208,17 @@ impl SockState { // This is the function called from the overlapped using as Arc>. Watch out for reference counting. fn feed_event(&mut self) -> Option { - if self.self_wrapped.is_some() { - // Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped. - self.self_wrapped.as_mut().unwrap().forget(); - self.self_wrapped = None; - } - self.poll_status = SockPollStatus::Idle; self.pending_evts = 0; let mut afd_events = 0; // We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK. unsafe { - let iosb = &*(*self.iosb).as_ptr(); if self.delete_pending { return None; - } else if iosb.u.Status == STATUS_CANCELLED { + } else if self.iosb.u.Status == STATUS_CANCELLED { /* The poll request was cancelled by CancelIoEx. */ - } else if !NT_SUCCESS(iosb.u.Status) { + } else if !NT_SUCCESS(self.iosb.u.Status) { /* The overlapped request itself failed in an unexpected way. */ afd_events = afd::POLL_CONNECT_FAIL; } else if self.poll_info.number_of_handles < 1 { @@ -310,6 +272,21 @@ impl SockState { } } +/// Converts the pointer to a `SockState` into a raw pointer. +/// To revert see `from_overlapped`. +fn into_overlapped(sock_state: Pin>>) -> PVOID { + let overlapped_ptr: *const Mutex = + unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) }; + overlapped_ptr as *mut _ +} + +/// Convert a raw overlapped pointer into a reference to `SockState`. +/// Reverts `into_overlapped`. +fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin>> { + let sock_ptr: *const Mutex = ptr as *const _; + unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) } +} + impl Drop for SockState { fn drop(&mut self) { self.mark_delete(); @@ -406,7 +383,7 @@ impl Selector { #[derive(Debug)] pub struct SelectorInner { cp: Arc, - update_queue: Mutex>>>, + update_queue: Mutex>>>>, afd_group: AfdGroup, is_polling: AtomicBool, } @@ -414,6 +391,41 @@ pub struct SelectorInner { // We have ensured thread safety by introducing lock manually. unsafe impl Sync for SelectorInner {} +impl Drop for SelectorInner { + fn drop(&mut self) { + loop { + let events_num: usize; + let mut statuses: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024]; + + let result = self + .cp + .get_many(&mut statuses, Some(std::time::Duration::from_millis(0))); + match result { + Ok(iocp_events) => { + events_num = iocp_events.iter().len(); + for iocp_event in iocp_events.iter() { + if !iocp_event.overlapped().is_null() { + // drain sock state to release memory of Arc reference + let _sock_state = from_overlapped(iocp_event.overlapped()); + } + } + } + + Err(_) => { + break; + } + } + + if events_num < 1024 { + // continue looping until all completion statuses have been drained + break; + } + } + + self.afd_group.release_unused_afd(); + } +} + impl SelectorInner { pub fn new() -> io::Result { CompletionPort::new(0).map(|cp| { @@ -602,8 +614,9 @@ impl SelectorInner { n += 1; continue; } - let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex); - let mut sock_guard = sock_arc.lock().unwrap(); + + let sock_state = from_overlapped(iocp_event.overlapped()); + let mut sock_guard = sock_state.lock().unwrap(); match sock_guard.feed_event() { Some(e) => { events.push(e); @@ -612,7 +625,7 @@ impl SelectorInner { } n += 1; if !sock_guard.is_pending_deletion() { - update_queue.push_back(sock_arc.clone()); + update_queue.push_back(sock_state.clone()); } } self.afd_group.release_unused_afd(); @@ -622,9 +635,9 @@ impl SelectorInner { fn _alloc_sock_for_rawsocket( &self, raw_socket: RawSocket, - ) -> io::Result>> { + ) -> io::Result>>> { let afd = self.afd_group.acquire()?; - Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?))) + Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?))) } } diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 57d5da569..ad0fd4419 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -8,6 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::{self, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. +use std::pin::Pin; use std::sync::{Arc, Mutex}; use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM}; @@ -127,7 +128,7 @@ impl TcpStream { } impl super::SocketState for TcpStream { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -137,7 +138,7 @@ impl super::SocketState for TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -160,7 +161,7 @@ impl super::SocketState for TcpStream { } impl<'a> super::SocketState for &'a TcpStream { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -170,7 +171,7 @@ impl<'a> super::SocketState for &'a TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -389,7 +390,7 @@ impl TcpListener { } impl super::SocketState for TcpListener { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -399,7 +400,7 @@ impl super::SocketState for TcpListener { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index ece7f8c74..cf4db6fc2 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -6,6 +6,7 @@ use crate::{event, poll, Interests, Registry, Token}; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. +use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::{fmt, io}; use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM}; @@ -160,7 +161,7 @@ impl UdpSocket { } impl super::SocketState for UdpSocket { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -170,7 +171,7 @@ impl super::SocketState for UdpSocket { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => {