Skip to content

Commit

Permalink
Fix IO_STATUS_BLOCK/AFD_POLL_INFO reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
piscisaureus committed Nov 13, 2019
1 parent cc1fd15 commit d6819d9
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 100 deletions.
32 changes: 0 additions & 32 deletions src/sys/windows/io_status_block.rs

This file was deleted.

8 changes: 4 additions & 4 deletions src/sys/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::mem::size_of_val;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::{Arc, Mutex, Once};
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
Expand Down Expand Up @@ -38,7 +39,6 @@ macro_rules! try_io {

mod afd;
pub mod event;
mod io_status_block;
mod selector;
mod tcp;
mod udp;
Expand All @@ -51,8 +51,8 @@ pub use udp::UdpSocket;
pub use waker::Waker;

pub trait SocketState {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>>;
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>);
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>>;
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>);
}

use crate::{Interests, Token};
Expand All @@ -61,7 +61,7 @@ struct InternalState {
selector: Arc<SelectorInner>,
token: Token,
interests: Interests,
sock_state: Option<Arc<Mutex<SockState>>>,
sock_state: Option<Pin<Arc<Mutex<SockState>>>>,
}

impl InternalState {
Expand Down
100 changes: 44 additions & 56 deletions src/sys/windows/selector.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use super::afd::{self, Afd, AfdPollInfo};
use super::io_status_block::IoStatusBlock;
use super::Event;
use super::SocketState;
use crate::sys::Events;
use crate::{Interests, Token};

use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use ntapi::ntioapi::IO_STATUS_BLOCK;
use std::collections::VecDeque;
use std::mem::size_of;
use std::fmt::{self, Debug, Formatter};
use std::mem::{forget, size_of, transmute_copy, MaybeUninit};
use std::ops::{Deref, DerefMut};
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::ptr::null_mut;
Expand Down Expand Up @@ -75,48 +77,45 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for overlapped pointer.
/// In case of error or status changing before the overlapped pointer is actually used(or not even being used),
/// this wrapper will decrease the reference count of Arc if being dropped.
/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);
enum SockPollStatus {
Idle,
Pending,
Cancelled,
}

unsafe impl<T> Send for OverlappedArcWrapper<T> {}
struct IoStatusBlock(IO_STATUS_BLOCK);

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
impl IoStatusBlock {
fn zeroed() -> Self {
Self(unsafe { MaybeUninit::<IO_STATUS_BLOCK>::zeroed().assume_init() })
}
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}
unsafe impl Send for IoStatusBlock {}

fn get_ptr(&self) -> *const T {
self.0
impl Debug for IoStatusBlock {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("IoStatusBlock").finish()
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
impl Deref for IoStatusBlock {
type Target = IO_STATUS_BLOCK;
fn deref(&self) -> &Self::Target {
&self.0
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
Pending,
Cancelled,
impl DerefMut for IoStatusBlock {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

#[derive(Debug)]
pub struct SockState {
iosb: Pin<Box<IoStatusBlock>>,
iosb: IoStatusBlock,
poll_info: AfdPollInfo,
afd: Arc<Afd>,

Expand All @@ -129,15 +128,13 @@ pub struct SockState {
user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,
}

impl SockState {
fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
Ok(SockState {
iosb: Pin::new(Box::new(IoStatusBlock::zeroed())),
iosb: IoStatusBlock::zeroed(),
poll_info: AfdPollInfo::zeroed(),
afd,
raw_socket,
Expand All @@ -146,7 +143,6 @@ impl SockState {
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
})
}
Expand All @@ -162,7 +158,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
Expand Down Expand Up @@ -192,11 +188,12 @@ impl SockState {
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
self.afd.poll(
&mut self.poll_info,
self.iosb.deref_mut(),
transmute_copy(self_arc),
)
};
if let Err(e) = result {
let code = e.raw_os_error().unwrap();
Expand All @@ -211,13 +208,9 @@ impl SockState {
}
}

if self.self_wrapped.is_some() {
// This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable.
unreachable!();
}
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
forget(self_arc.clone());
} else {
unreachable!();
}
Expand All @@ -230,7 +223,7 @@ impl SockState {
_ => unreachable!(),
};
unsafe {
self.afd.cancel((*self.iosb).as_mut_ptr())?;
self.afd.cancel(self.iosb.deref_mut())?;
}
self.poll_status = SockPollStatus::Cancelled;
self.pending_evts = 0;
Expand All @@ -239,24 +232,17 @@ impl SockState {

// This is the function called from the overlapped using as Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

let mut afd_events = 0;
// We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK.
unsafe {
let iosb = &*(*self.iosb).as_ptr();
if self.delete_pending {
return None;
} else if iosb.u.Status == STATUS_CANCELLED {
} else if self.iosb.u.Status == STATUS_CANCELLED {
/* The poll request was cancelled by CancelIoEx. */
} else if !NT_SUCCESS(iosb.u.Status) {
} else if !NT_SUCCESS(self.iosb.u.Status) {
/* The overlapped request itself failed in an unexpected way. */
afd_events = afd::POLL_CONNECT_FAIL;
} else if self.poll_info.number_of_handles < 1 {
Expand Down Expand Up @@ -406,7 +392,7 @@ impl Selector {
#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}
Expand Down Expand Up @@ -626,7 +612,7 @@ impl SelectorInner {
n += 1;
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let sock_arc: Pin<Arc<Mutex<SockState>>> = transmute_copy(&iocp_event.overlapped());
let mut sock_guard = sock_arc.lock().unwrap();
match sock_guard.feed_event() {
Some(e) => {
Expand All @@ -646,9 +632,11 @@ impl SelectorInner {
fn _alloc_sock_for_rawsocket(
&self,
raw_socket: RawSocket,
) -> io::Result<Arc<Mutex<SockState>>> {
) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
let afd = self.afd_group.acquire()?;
Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?)))
Ok(Pin::new(Arc::new(Mutex::new(SockState::new(
raw_socket, afd,
)?))))
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/sys/windows/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{self, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM};

Expand Down Expand Up @@ -127,7 +128,7 @@ impl TcpStream {
}

impl super::SocketState for TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -137,7 +138,7 @@ impl super::SocketState for TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand All @@ -160,7 +161,7 @@ impl super::SocketState for TcpStream {
}

impl<'a> super::SocketState for &'a TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl<'a> super::SocketState for &'a TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down Expand Up @@ -405,7 +406,7 @@ impl TcpListener {
}

impl super::SocketState for TcpListener {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -415,7 +416,7 @@ impl super::SocketState for TcpListener {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down
5 changes: 3 additions & 2 deletions src/sys/windows/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{event, poll, Interests, Registry, Token};
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::{fmt, io};
use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM};
Expand Down Expand Up @@ -160,7 +161,7 @@ impl UdpSocket {
}

impl super::SocketState for UdpSocket {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl super::SocketState for UdpSocket {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down

0 comments on commit d6819d9

Please sign in to comment.