diff --git a/src/sys/windows/afd.rs b/src/sys/windows/afd.rs index 88f1c23c7e..b4aa91b210 100644 --- a/src/sys/windows/afd.rs +++ b/src/sys/windows/afd.rs @@ -144,18 +144,25 @@ impl Afd { /// /// # Unsafety /// - /// This function is unsafe due to memory of `IO_STATUS_BLOCK` still being used by `Afd` instance while `Ok(false)` (`STATUS_PENDING`). - /// `iosb` needs to be untouched after the call while operation is in effective at ALL TIME except for `cancel` method. - /// So be careful not to `poll` twice while polling. - /// User should deallocate there overlapped value when error to prevent memory leak. + /// This function is unsafe because the memory of `IO_STATUS_BLOCK` and + /// `AfdPollInfo` may not be freed after poll() returns `Ok(_)`. + /// + /// If this function returns `Ok(false)` the operation is pending. The + /// `IO_STATUS_BLOCK` and `AfdPollInfo` structures will be updated by the + /// windows kernel at a later time, and after that the `overlapped` pointer + /// will be reported by the I/O completion port. + /// + /// If this function returns `Ok(true)`, the operation has already been + /// completed, but the `overlapped` pointer will still be received by the + /// I/O completion port. pub unsafe fn poll( &self, info: &mut AfdPollInfo, - iosb: *mut IO_STATUS_BLOCK, + iosb: &mut IO_STATUS_BLOCK, overlapped: PVOID, ) -> io::Result { let info_ptr: PVOID = info as *mut _ as PVOID; - (*iosb).u.Status = STATUS_PENDING; + iosb.u.Status = STATUS_PENDING; let status = NtDeviceIoControlFile( self.fd.as_raw_handle(), null_mut(), @@ -186,8 +193,8 @@ impl Afd { /// This function is unsafe due to memory of `IO_STATUS_BLOCK` still being used by `Afd` instance while `Ok(false)` (`STATUS_PENDING`). /// Use it only with request is still being polled so that you have valid `IO_STATUS_BLOCK` to use. /// User should NOT deallocate there overlapped value after the `cancel` to prevent double free. - pub unsafe fn cancel(&self, iosb: *mut IO_STATUS_BLOCK) -> io::Result<()> { - if (*iosb).u.Status != STATUS_PENDING { + pub unsafe fn cancel(&self, iosb: &mut IO_STATUS_BLOCK) -> io::Result<()> { + if iosb.u.Status != STATUS_PENDING { return Ok(()); } diff --git a/src/sys/windows/io_status_block.rs b/src/sys/windows/io_status_block.rs index 9035b7b19c..f410813ced 100644 --- a/src/sys/windows/io_status_block.rs +++ b/src/sys/windows/io_status_block.rs @@ -1,32 +1,33 @@ -use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK}; -use std::cell::UnsafeCell; -use std::fmt; +use ntapi::ntioapi::IO_STATUS_BLOCK; +use std::fmt::{self, Debug, Formatter}; +use std::mem::MaybeUninit; +use std::ops::{Deref, DerefMut}; -pub struct IoStatusBlock(UnsafeCell); - -// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use that. Thus it is safe to implement Send here. -unsafe impl Send for IoStatusBlock {} +pub struct IoStatusBlock(IO_STATUS_BLOCK); impl IoStatusBlock { - pub fn zeroed() -> IoStatusBlock { - let iosb = IO_STATUS_BLOCK { - u: IO_STATUS_BLOCK_u { Status: 0 }, - Information: 0, - }; - IoStatusBlock(UnsafeCell::new(iosb)) + pub fn zeroed() -> Self { + Self(unsafe { MaybeUninit::::zeroed().assume_init() }) } +} + +unsafe impl Send for IoStatusBlock {} - pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK { - self.0.get() +impl Debug for IoStatusBlock { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("IoStatusBlock").finish() } +} - pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK { - self.0.get() +impl Deref for IoStatusBlock { + type Target = IO_STATUS_BLOCK; + fn deref(&self) -> &Self::Target { + &self.0 } } -impl fmt::Debug for IoStatusBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IoStatusBlock").finish() +impl DerefMut for IoStatusBlock { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 0a6eff7560..33e58d8413 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,6 +1,7 @@ use std::io; use std::mem::size_of_val; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::pin::Pin; use std::sync::{Arc, Mutex, Once}; use winapi::ctypes::c_int; use winapi::shared::ws2def::SOCKADDR; @@ -45,14 +46,15 @@ mod udp; mod waker; pub use event::{Event, Events}; +pub use io_status_block::IoStatusBlock; pub use selector::{Selector, SelectorInner, SockState}; pub use tcp::{TcpListener, TcpStream}; pub use udp::UdpSocket; pub use waker::Waker; pub trait SocketState { - fn get_sock_state(&self) -> Option>>; - fn set_sock_state(&self, sock_state: Option>>); + fn get_sock_state(&self) -> Option>>>; + fn set_sock_state(&self, sock_state: Option>>>); } use crate::{Interests, Token}; @@ -61,7 +63,7 @@ struct InternalState { selector: Arc, token: Token, interests: Interests, - sock_state: Option>>, + sock_state: Option>>>, } impl InternalState { diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index cbb84aac20..5aa4a5b9a6 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -1,14 +1,16 @@ use super::afd::{self, Afd, AfdPollInfo}; -use super::io_status_block::IoStatusBlock; use super::Event; +use super::IoStatusBlock; use super::SocketState; + use crate::sys::Events; use crate::{Interests, Token}; use miow::iocp::{CompletionPort, CompletionStatus}; use miow::Overlapped; use std::collections::VecDeque; -use std::mem::size_of; +use std::marker::PhantomPinned; +use std::mem::{forget, size_of, transmute_copy}; use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::ptr::null_mut; @@ -18,8 +20,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use std::{io, ptr}; -use winapi::shared::ntdef::NT_SUCCESS; -use winapi::shared::ntdef::{HANDLE, PVOID}; +use winapi::shared::ntdef::{HANDLE, NT_SUCCESS, PVOID}; use winapi::shared::ntstatus::STATUS_CANCELLED; use winapi::shared::winerror::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, WAIT_TIMEOUT}; use winapi::um::mswsock::SIO_BASE_HANDLE; @@ -75,38 +76,6 @@ impl AfdGroup { } } -/// This is the deallocation wrapper for overlapped pointer. -/// In case of error or status changing before the overlapped pointer is actually used(or not even being used), -/// this wrapper will decrease the reference count of Arc if being dropped. -/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free. -#[derive(Debug)] -struct OverlappedArcWrapper(*const T); - -unsafe impl Send for OverlappedArcWrapper {} - -impl OverlappedArcWrapper { - fn new(arc: &Arc) -> OverlappedArcWrapper { - OverlappedArcWrapper(Arc::into_raw(arc.clone())) - } - - fn forget(&mut self) { - self.0 = 0 as *const T; - } - - fn get_ptr(&self) -> *const T { - self.0 - } -} - -impl Drop for OverlappedArcWrapper { - fn drop(&mut self) { - if self.0 as usize == 0 { - return; - } - drop(unsafe { Arc::from_raw(self.0) }); - } -} - #[derive(Debug)] enum SockPollStatus { Idle, @@ -116,7 +85,7 @@ enum SockPollStatus { #[derive(Debug)] pub struct SockState { - iosb: Pin>, + iosb: IoStatusBlock, poll_info: AfdPollInfo, afd: Arc, @@ -129,15 +98,15 @@ pub struct SockState { user_data: u64, poll_status: SockPollStatus, - self_wrapped: Option>>, - delete_pending: bool, + + pinned: PhantomPinned, } impl SockState { fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { Ok(SockState { - iosb: Pin::new(Box::new(IoStatusBlock::zeroed())), + iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), afd, raw_socket, @@ -146,8 +115,8 @@ impl SockState { pending_evts: 0, user_data: 0, poll_status: SockPollStatus::Idle, - self_wrapped: None, delete_pending: false, + pinned: PhantomPinned, }) } @@ -162,7 +131,7 @@ impl SockState { (events & !self.pending_evts) != 0 } - fn update(&mut self, self_arc: &Arc>) -> io::Result<()> { + fn update(&mut self, self_arc: &Pin>>) -> io::Result<()> { assert!(!self.delete_pending); if let SockPollStatus::Pending = self.poll_status { @@ -192,11 +161,12 @@ impl SockState { self.poll_info.handles[0].status = 0; self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE; - let wrapped_overlapped = OverlappedArcWrapper::new(self_arc); - let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID; let result = unsafe { - self.afd - .poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped) + self.afd.poll( + &mut self.poll_info, + &mut self.iosb, + transmute_copy(self_arc), + ) }; if let Err(e) = result { let code = e.raw_os_error().unwrap(); @@ -211,13 +181,9 @@ impl SockState { } } - if self.self_wrapped.is_some() { - // This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable. - unreachable!(); - } self.poll_status = SockPollStatus::Pending; - self.self_wrapped = Some(wrapped_overlapped); self.pending_evts = self.user_evts; + forget(self_arc.clone()); } else { unreachable!(); } @@ -230,7 +196,7 @@ impl SockState { _ => unreachable!(), }; unsafe { - self.afd.cancel((*self.iosb).as_mut_ptr())?; + self.afd.cancel(&mut self.iosb)?; } self.poll_status = SockPollStatus::Cancelled; self.pending_evts = 0; @@ -239,24 +205,17 @@ impl SockState { // This is the function called from the overlapped using as Arc>. Watch out for reference counting. fn feed_event(&mut self) -> Option { - if self.self_wrapped.is_some() { - // Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped. - self.self_wrapped.as_mut().unwrap().forget(); - self.self_wrapped = None; - } - self.poll_status = SockPollStatus::Idle; self.pending_evts = 0; let mut afd_events = 0; // We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK. unsafe { - let iosb = &*(*self.iosb).as_ptr(); if self.delete_pending { return None; - } else if iosb.u.Status == STATUS_CANCELLED { + } else if self.iosb.u.Status == STATUS_CANCELLED { /* The poll request was cancelled by CancelIoEx. */ - } else if !NT_SUCCESS(iosb.u.Status) { + } else if !NT_SUCCESS(self.iosb.u.Status) { /* The overlapped request itself failed in an unexpected way. */ afd_events = afd::POLL_CONNECT_FAIL; } else if self.poll_info.number_of_handles < 1 { @@ -406,7 +365,7 @@ impl Selector { #[derive(Debug)] pub struct SelectorInner { cp: Arc, - update_queue: Mutex>>>, + update_queue: Mutex>>>>, afd_group: AfdGroup, is_polling: AtomicBool, } @@ -626,7 +585,7 @@ impl SelectorInner { n += 1; continue; } - let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex); + let sock_arc: Pin>> = transmute_copy(&iocp_event.overlapped()); let mut sock_guard = sock_arc.lock().unwrap(); match sock_guard.feed_event() { Some(e) => { @@ -646,9 +605,9 @@ impl SelectorInner { fn _alloc_sock_for_rawsocket( &self, raw_socket: RawSocket, - ) -> io::Result>> { + ) -> io::Result>>> { let afd = self.afd_group.acquire()?; - Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?))) + Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?))) } } diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 8dbc7d8e0c..368de55649 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -8,6 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::{self, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. +use std::pin::Pin; use std::sync::{Arc, Mutex}; use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM}; @@ -127,7 +128,7 @@ impl TcpStream { } impl super::SocketState for TcpStream { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -137,7 +138,7 @@ impl super::SocketState for TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -160,7 +161,7 @@ impl super::SocketState for TcpStream { } impl<'a> super::SocketState for &'a TcpStream { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -170,7 +171,7 @@ impl<'a> super::SocketState for &'a TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -405,7 +406,7 @@ impl TcpListener { } impl super::SocketState for TcpListener { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -415,7 +416,7 @@ impl super::SocketState for TcpListener { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index ece7f8c745..cf4db6fc29 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -6,6 +6,7 @@ use crate::{event, poll, Interests, Registry, Token}; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. +use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::{fmt, io}; use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM}; @@ -160,7 +161,7 @@ impl UdpSocket { } impl super::SocketState for UdpSocket { - fn get_sock_state(&self) -> Option>> { + fn get_sock_state(&self) -> Option>>> { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -170,7 +171,7 @@ impl super::SocketState for UdpSocket { None => None, } } - fn set_sock_state(&self, sock_state: Option>>) { + fn set_sock_state(&self, sock_state: Option>>>) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => {