Skip to content

Commit

Permalink
Simplify windows SockState reference counting (#1154)
Browse files Browse the repository at this point in the history
  • Loading branch information
piscisaureus committed Nov 26, 2019
1 parent 039b09c commit 88407fb
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 88 deletions.
29 changes: 16 additions & 13 deletions src/sys/windows/io_status_block.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK};
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};

pub struct IoStatusBlock(UnsafeCell<IO_STATUS_BLOCK>);

// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use that. Thus it is safe to implement Send here.
unsafe impl Send for IoStatusBlock {}
pub struct IoStatusBlock(IO_STATUS_BLOCK);

impl IoStatusBlock {
pub fn zeroed() -> IoStatusBlock {
let iosb = IO_STATUS_BLOCK {
pub fn zeroed() -> Self {
Self(IO_STATUS_BLOCK {
u: IO_STATUS_BLOCK_u { Status: 0 },
Information: 0,
};
IoStatusBlock(UnsafeCell::new(iosb))
})
}
}

pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK {
self.0.get()
unsafe impl Send for IoStatusBlock {}

impl Deref for IoStatusBlock {
type Target = IO_STATUS_BLOCK;
fn deref(&self) -> &Self::Target {
&self.0
}
}

pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK {
self.0.get()
impl DerefMut for IoStatusBlock {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

Expand Down
11 changes: 8 additions & 3 deletions src/sys/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::mem::size_of_val;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::{Arc, Mutex, Once};
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
Expand Down Expand Up @@ -64,8 +65,12 @@ pub use udp::UdpSocket;
pub use waker::Waker;

pub trait SocketState {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>>;
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>);
// The `SockState` struct needs to be pinned in memory because it contains
// `OVERLAPPED` and `AFD_POLL_INFO` fields which are modified in the
// background by the windows kernel, therefore we need to ensure they are
// never moved to a different memory address.
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>>;
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>);
}

use crate::{Interests, Token};
Expand All @@ -74,7 +79,7 @@ struct InternalState {
selector: Arc<SelectorInner>,
token: Token,
interests: Interests,
sock_state: Option<Arc<Mutex<SockState>>>,
sock_state: Option<Pin<Arc<Mutex<SockState>>>>,
}

impl InternalState {
Expand Down
92 changes: 28 additions & 64 deletions src/sys/windows/selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::{Interests, Token};
use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use std::collections::VecDeque;
use std::mem::size_of;
use std::marker::PhantomPinned;
use std::mem::{forget, size_of, transmute_copy};
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::ptr::null_mut;
Expand Down Expand Up @@ -75,38 +76,6 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for overlapped pointer.
/// In case of error or status changing before the overlapped pointer is actually used(or not even being used),
/// this wrapper will decrease the reference count of Arc if being dropped.
/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);

unsafe impl<T> Send for OverlappedArcWrapper<T> {}

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}

fn get_ptr(&self) -> *const T {
self.0
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
Expand All @@ -116,7 +85,7 @@ enum SockPollStatus {

#[derive(Debug)]
pub struct SockState {
iosb: Pin<Box<IoStatusBlock>>,
iosb: IoStatusBlock,
poll_info: AfdPollInfo,
afd: Arc<Afd>,

Expand All @@ -129,15 +98,15 @@ pub struct SockState {
user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,

pinned: PhantomPinned,
}

impl SockState {
fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
Ok(SockState {
iosb: Pin::new(Box::new(IoStatusBlock::zeroed())),
iosb: IoStatusBlock::zeroed(),
poll_info: AfdPollInfo::zeroed(),
afd,
raw_socket,
Expand All @@ -146,8 +115,8 @@ impl SockState {
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
pinned: PhantomPinned,
})
}

Expand All @@ -162,7 +131,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
Expand All @@ -185,18 +154,16 @@ impl SockState {
/* No poll operation is pending; start one. */
self.poll_info.exclusive = 0;
self.poll_info.number_of_handles = 1;
unsafe {
*self.poll_info.timeout.QuadPart_mut() = std::i64::MAX;
}
*unsafe { self.poll_info.timeout.QuadPart_mut() } = std::i64::MAX;
self.poll_info.handles[0].handle = self.base_socket as HANDLE;
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
let overlapped_ptr: PVOID = unsafe { transmute_copy(self_arc) };

let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
.poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
};
if let Err(e) = result {
let code = e.raw_os_error().unwrap();
Expand All @@ -211,12 +178,16 @@ impl SockState {
}
}

if self.self_wrapped.is_some() {
// This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable.
unreachable!();
}
// We've effectively created another reference to the SockState
// struct by transmuting `self_arc` to a raw pointer, but the Arc's
// reference count has not been increased so far. Now that the
// AFD_POLL operation has succesfully started, it is certain that
// once the operation completes, this raw pointer will be received
// back from the kernel and converted to an actual Arc (in
// `feed_events()`). Therefore increase the Arc reference count here
// by cloning it and immediately leaking the clone.
forget(self_arc.clone());
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
} else {
unreachable!();
Expand All @@ -230,7 +201,7 @@ impl SockState {
_ => unreachable!(),
};
unsafe {
self.afd.cancel((*self.iosb).as_mut_ptr())?;
self.afd.cancel(&mut *self.iosb)?;
}
self.poll_status = SockPollStatus::Cancelled;
self.pending_evts = 0;
Expand All @@ -239,24 +210,17 @@ impl SockState {

// This is the function called from the overlapped using as Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

let mut afd_events = 0;
// We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK.
unsafe {
let iosb = &*(*self.iosb).as_ptr();
if self.delete_pending {
return None;
} else if iosb.u.Status == STATUS_CANCELLED {
} else if self.iosb.u.Status == STATUS_CANCELLED {
/* The poll request was cancelled by CancelIoEx. */
} else if !NT_SUCCESS(iosb.u.Status) {
} else if !NT_SUCCESS(self.iosb.u.Status) {
/* The overlapped request itself failed in an unexpected way. */
afd_events = afd::POLL_CONNECT_FAIL;
} else if self.poll_info.number_of_handles < 1 {
Expand Down Expand Up @@ -406,7 +370,7 @@ impl Selector {
#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}
Expand Down Expand Up @@ -602,7 +566,7 @@ impl SelectorInner {
n += 1;
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let sock_arc: Pin<Arc<Mutex<SockState>>> = transmute_copy(&iocp_event.overlapped());
let mut sock_guard = sock_arc.lock().unwrap();
match sock_guard.feed_event() {
Some(e) => {
Expand All @@ -622,9 +586,9 @@ impl SelectorInner {
fn _alloc_sock_for_rawsocket(
&self,
raw_socket: RawSocket,
) -> io::Result<Arc<Mutex<SockState>>> {
) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
let afd = self.afd_group.acquire()?;
Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?)))
Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?)))
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/sys/windows/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{self, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM};

Expand Down Expand Up @@ -127,7 +128,7 @@ impl TcpStream {
}

impl super::SocketState for TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -137,7 +138,7 @@ impl super::SocketState for TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand All @@ -160,7 +161,7 @@ impl super::SocketState for TcpStream {
}

impl<'a> super::SocketState for &'a TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl<'a> super::SocketState for &'a TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down Expand Up @@ -389,7 +390,7 @@ impl TcpListener {
}

impl super::SocketState for TcpListener {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -399,7 +400,7 @@ impl super::SocketState for TcpListener {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down
5 changes: 3 additions & 2 deletions src/sys/windows/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{event, poll, Interests, Registry, Token};
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::{fmt, io};
use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM};
Expand Down Expand Up @@ -160,7 +161,7 @@ impl UdpSocket {
}

impl super::SocketState for UdpSocket {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl super::SocketState for UdpSocket {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down

0 comments on commit 88407fb

Please sign in to comment.