Skip to content

Commit

Permalink
Fix IO_STATUS_BLOCK/AFD_POLL_INFO reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
piscisaureus committed Nov 14, 2019
1 parent cc1fd15 commit 1522d5f
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 104 deletions.
23 changes: 15 additions & 8 deletions src/sys/windows/afd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,18 +144,25 @@ impl Afd {
///
/// # Unsafety
///
/// This function is unsafe due to memory of `IO_STATUS_BLOCK` still being used by `Afd` instance while `Ok(false)` (`STATUS_PENDING`).
/// `iosb` needs to be untouched after the call while operation is in effective at ALL TIME except for `cancel` method.
/// So be careful not to `poll` twice while polling.
/// User should deallocate there overlapped value when error to prevent memory leak.
/// This function is unsafe because the memory of `IO_STATUS_BLOCK` and
/// `AfdPollInfo` may not be freed after poll() returns `Ok(_)`.
///
/// If this function returns `Ok(false)` the operation is pending. The
/// `IO_STATUS_BLOCK` and `AfdPollInfo` structures will be updated by the
/// windows kernel at a later time, and after that the `overlapped` pointer
/// will be reported by the I/O completion port.
///
/// If this function returns `Ok(true)`, the operation has already been
/// completed, but the `overlapped` pointer will still be received by the
/// I/O completion port.
pub unsafe fn poll(
&self,
info: &mut AfdPollInfo,
iosb: *mut IO_STATUS_BLOCK,
iosb: &mut IO_STATUS_BLOCK,
overlapped: PVOID,
) -> io::Result<bool> {
let info_ptr: PVOID = info as *mut _ as PVOID;
(*iosb).u.Status = STATUS_PENDING;
iosb.u.Status = STATUS_PENDING;
let status = NtDeviceIoControlFile(
self.fd.as_raw_handle(),
null_mut(),
Expand Down Expand Up @@ -186,8 +193,8 @@ impl Afd {
/// This function is unsafe due to memory of `IO_STATUS_BLOCK` still being used by `Afd` instance while `Ok(false)` (`STATUS_PENDING`).
/// Use it only with request is still being polled so that you have valid `IO_STATUS_BLOCK` to use.
/// User should NOT deallocate there overlapped value after the `cancel` to prevent double free.
pub unsafe fn cancel(&self, iosb: *mut IO_STATUS_BLOCK) -> io::Result<()> {
if (*iosb).u.Status != STATUS_PENDING {
pub unsafe fn cancel(&self, iosb: &mut IO_STATUS_BLOCK) -> io::Result<()> {
if iosb.u.Status != STATUS_PENDING {
return Ok(());
}

Expand Down
41 changes: 21 additions & 20 deletions src/sys/windows/io_status_block.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK};
use std::cell::UnsafeCell;
use std::fmt;
use ntapi::ntioapi::IO_STATUS_BLOCK;
use std::fmt::{self, Debug, Formatter};
use std::mem::MaybeUninit;
use std::ops::{Deref, DerefMut};

pub struct IoStatusBlock(UnsafeCell<IO_STATUS_BLOCK>);

// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use that. Thus it is safe to implement Send here.
unsafe impl Send for IoStatusBlock {}
pub struct IoStatusBlock(IO_STATUS_BLOCK);

impl IoStatusBlock {
pub fn zeroed() -> IoStatusBlock {
let iosb = IO_STATUS_BLOCK {
u: IO_STATUS_BLOCK_u { Status: 0 },
Information: 0,
};
IoStatusBlock(UnsafeCell::new(iosb))
pub fn zeroed() -> Self {
Self(unsafe { MaybeUninit::<IO_STATUS_BLOCK>::zeroed().assume_init() })
}
}

unsafe impl Send for IoStatusBlock {}

pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK {
self.0.get()
impl Debug for IoStatusBlock {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("IoStatusBlock").finish()
}
}

pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK {
self.0.get()
impl Deref for IoStatusBlock {
type Target = IO_STATUS_BLOCK;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl fmt::Debug for IoStatusBlock {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoStatusBlock").finish()
impl DerefMut for IoStatusBlock {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
8 changes: 5 additions & 3 deletions src/sys/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::mem::size_of_val;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::{Arc, Mutex, Once};
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
Expand Down Expand Up @@ -45,14 +46,15 @@ mod udp;
mod waker;

pub use event::{Event, Events};
pub use io_status_block::IoStatusBlock;
pub use selector::{Selector, SelectorInner, SockState};
pub use tcp::{TcpListener, TcpStream};
pub use udp::UdpSocket;
pub use waker::Waker;

pub trait SocketState {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>>;
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>);
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>>;
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>);
}

use crate::{Interests, Token};
Expand All @@ -61,7 +63,7 @@ struct InternalState {
selector: Arc<SelectorInner>,
token: Token,
interests: Interests,
sock_state: Option<Arc<Mutex<SockState>>>,
sock_state: Option<Pin<Arc<Mutex<SockState>>>>,
}

impl InternalState {
Expand Down
89 changes: 24 additions & 65 deletions src/sys/windows/selector.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use super::afd::{self, Afd, AfdPollInfo};
use super::io_status_block::IoStatusBlock;
use super::Event;
use super::IoStatusBlock;
use super::SocketState;

use crate::sys::Events;
use crate::{Interests, Token};

use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use std::collections::VecDeque;
use std::mem::size_of;
use std::marker::PhantomPinned;
use std::mem::{forget, size_of, transmute_copy};
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::ptr::null_mut;
Expand All @@ -18,8 +20,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{io, ptr};
use winapi::shared::ntdef::NT_SUCCESS;
use winapi::shared::ntdef::{HANDLE, PVOID};
use winapi::shared::ntdef::{HANDLE, NT_SUCCESS, PVOID};
use winapi::shared::ntstatus::STATUS_CANCELLED;
use winapi::shared::winerror::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, WAIT_TIMEOUT};
use winapi::um::mswsock::SIO_BASE_HANDLE;
Expand Down Expand Up @@ -75,38 +76,6 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for overlapped pointer.
/// In case of error or status changing before the overlapped pointer is actually used(or not even being used),
/// this wrapper will decrease the reference count of Arc if being dropped.
/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);

unsafe impl<T> Send for OverlappedArcWrapper<T> {}

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}

fn get_ptr(&self) -> *const T {
self.0
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
Expand All @@ -116,7 +85,7 @@ enum SockPollStatus {

#[derive(Debug)]
pub struct SockState {
iosb: Pin<Box<IoStatusBlock>>,
iosb: IoStatusBlock,
poll_info: AfdPollInfo,
afd: Arc<Afd>,

Expand All @@ -129,15 +98,15 @@ pub struct SockState {
user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,

pinned: PhantomPinned,
}

impl SockState {
fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
Ok(SockState {
iosb: Pin::new(Box::new(IoStatusBlock::zeroed())),
iosb: IoStatusBlock::zeroed(),
poll_info: AfdPollInfo::zeroed(),
afd,
raw_socket,
Expand All @@ -146,8 +115,8 @@ impl SockState {
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
pinned: PhantomPinned,
})
}

Expand All @@ -162,7 +131,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
Expand Down Expand Up @@ -192,11 +161,12 @@ impl SockState {
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
self.afd.poll(
&mut self.poll_info,
&mut self.iosb,
transmute_copy(self_arc),
)
};
if let Err(e) = result {
let code = e.raw_os_error().unwrap();
Expand All @@ -211,13 +181,9 @@ impl SockState {
}
}

if self.self_wrapped.is_some() {
// This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable.
unreachable!();
}
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
forget(self_arc.clone());
} else {
unreachable!();
}
Expand All @@ -230,7 +196,7 @@ impl SockState {
_ => unreachable!(),
};
unsafe {
self.afd.cancel((*self.iosb).as_mut_ptr())?;
self.afd.cancel(&mut self.iosb)?;
}
self.poll_status = SockPollStatus::Cancelled;
self.pending_evts = 0;
Expand All @@ -239,24 +205,17 @@ impl SockState {

// This is the function called from the overlapped using as Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

let mut afd_events = 0;
// We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK.
unsafe {
let iosb = &*(*self.iosb).as_ptr();
if self.delete_pending {
return None;
} else if iosb.u.Status == STATUS_CANCELLED {
} else if self.iosb.u.Status == STATUS_CANCELLED {
/* The poll request was cancelled by CancelIoEx. */
} else if !NT_SUCCESS(iosb.u.Status) {
} else if !NT_SUCCESS(self.iosb.u.Status) {
/* The overlapped request itself failed in an unexpected way. */
afd_events = afd::POLL_CONNECT_FAIL;
} else if self.poll_info.number_of_handles < 1 {
Expand Down Expand Up @@ -406,7 +365,7 @@ impl Selector {
#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}
Expand Down Expand Up @@ -626,7 +585,7 @@ impl SelectorInner {
n += 1;
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let sock_arc: Pin<Arc<Mutex<SockState>>> = transmute_copy(&iocp_event.overlapped());
let mut sock_guard = sock_arc.lock().unwrap();
match sock_guard.feed_event() {
Some(e) => {
Expand All @@ -646,9 +605,9 @@ impl SelectorInner {
fn _alloc_sock_for_rawsocket(
&self,
raw_socket: RawSocket,
) -> io::Result<Arc<Mutex<SockState>>> {
) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
let afd = self.afd_group.acquire()?;
Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?)))
Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?)))
}
}

Expand Down
Loading

0 comments on commit 1522d5f

Please sign in to comment.