Simplify windows SockState reference counting #1154

Merged: 4 commits, Nov 28, 2019
29 changes: 16 additions & 13 deletions src/sys/windows/io_status_block.rs
@@ -1,27 +1,30 @@
use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK};
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};

pub struct IoStatusBlock(UnsafeCell<IO_STATUS_BLOCK>);

// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use. Thus it is safe to implement Send here.
unsafe impl Send for IoStatusBlock {}
pub struct IoStatusBlock(IO_STATUS_BLOCK);

impl IoStatusBlock {
pub fn zeroed() -> IoStatusBlock {
let iosb = IO_STATUS_BLOCK {
pub fn zeroed() -> Self {
Self(IO_STATUS_BLOCK {
u: IO_STATUS_BLOCK_u { Status: 0 },
Information: 0,
};
IoStatusBlock(UnsafeCell::new(iosb))
})
}
}

pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK {
self.0.get()
unsafe impl Send for IoStatusBlock {}

impl Deref for IoStatusBlock {
type Target = IO_STATUS_BLOCK;
fn deref(&self) -> &Self::Target {
&self.0
}
}

pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK {
self.0.get()
impl DerefMut for IoStatusBlock {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

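The io_status_block.rs change above drops the UnsafeCell wrapper and the as_ptr/as_mut_ptr helpers in favour of a plain newtype that implements Deref and DerefMut, so call sites can simply take &mut *iosb. Below is a minimal sketch of the same pattern; FfiBlock and Wrapper are made-up stand-ins for IO_STATUS_BLOCK and IoStatusBlock so the snippet compiles on any platform.

use std::ops::{Deref, DerefMut};

#[repr(C)]
#[derive(Default)]
struct FfiBlock {
    status: i32,
    information: usize,
}

struct Wrapper(FfiBlock);

impl Deref for Wrapper {
    type Target = FfiBlock;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Wrapper {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

fn main() {
    let mut block = Wrapper(FfiBlock::default());
    // `&mut *block` auto-derefs to `&mut FfiBlock`, and at an FFI boundary
    // that coerces to `*mut FfiBlock`, which is what the removed
    // `as_mut_ptr()` helper used to hand out.
    let raw: *mut FfiBlock = &mut *block;
    unsafe { (*raw).status = 1 };
    println!("status = {}, information = {}", block.status, block.information);
}

The manual Send impl is still needed in the real code because the union inside IO_STATUS_BLOCK contains a pointer field, which would otherwise make the type !Send even though that field is never used.
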
11 changes: 8 additions & 3 deletions src/sys/windows/mod.rs
@@ -1,6 +1,7 @@
use std::io;
use std::mem::size_of_val;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::{Arc, Mutex, Once};
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
@@ -64,8 +65,12 @@ pub use udp::UdpSocket;
pub use waker::Waker;

pub trait SocketState {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>>;
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>);
// The `SockState` struct needs to be pinned in memory because it contains
// `OVERLAPPED` and `AFD_POLL_INFO` fields which are modified in the
// background by the windows kernel, therefore we need to ensure they are
// never moved to a different memory address.
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>>;
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>);
}

use crate::{Interests, Token};
@@ -74,7 +79,7 @@ struct InternalState {
selector: Arc<SelectorInner>,
token: Token,
interests: Interests,
sock_state: Option<Arc<Mutex<SockState>>>,
sock_state: Option<Pin<Arc<Mutex<SockState>>>>,
}

impl InternalState {
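
The trait comment above is the crux of this PR: the Windows kernel keeps writing into the OVERLAPPED and AFD_POLL_INFO fields in the background, so a SockState must never move once an operation is in flight. Below is a minimal sketch, with a made-up KernelOwnedState type rather than mio's real one, of how Pin<Arc<Mutex<T>>> plus PhantomPinned expresses that guarantee.

use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

struct KernelOwnedState {
    // Imagine OVERLAPPED / AFD_POLL_INFO here; the kernel writes to them.
    bytes_transferred: u64,
    _pinned: PhantomPinned, // opts the type out of Unpin
}

fn alloc_state() -> Pin<Arc<Mutex<KernelOwnedState>>> {
    // Arc::pin allocates on the heap and returns a pinned handle, so safe
    // code can no longer move the value to another address.
    Arc::pin(Mutex::new(KernelOwnedState {
        bytes_transferred: 0,
        _pinned: PhantomPinned,
    }))
}

fn main() {
    let state = alloc_state();
    // This address stays valid for as long as the allocation is alive; it is
    // what the real selector hands to the kernel as a raw pointer.
    let addr = &*state as *const Mutex<KernelOwnedState> as usize;
    println!("pinned state lives at {:#x}", addr);
    println!("bytes so far: {}", state.lock().unwrap().bytes_transferred);
}
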
153 changes: 83 additions & 70 deletions src/sys/windows/selector.rs
@@ -8,6 +8,7 @@ use crate::{Interests, Token};
use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use std::collections::VecDeque;
use std::marker::PhantomPinned;
use std::mem::size_of;
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
@@ -22,6 +23,7 @@ use winapi::shared::ntdef::NT_SUCCESS;
use winapi::shared::ntdef::{HANDLE, PVOID};
use winapi::shared::ntstatus::STATUS_CANCELLED;
use winapi::shared::winerror::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, WAIT_TIMEOUT};
use winapi::um::minwinbase::OVERLAPPED;
use winapi::um::mswsock::SIO_BASE_HANDLE;
use winapi::um::winsock2::{WSAIoctl, INVALID_SOCKET, SOCKET_ERROR};

@@ -75,38 +77,6 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for the overlapped pointer.
/// In case of an error, or the status changing before the overlapped pointer is actually used (or it never being used at all),
/// this wrapper will decrease the reference count of the Arc when dropped.
/// Remember to call `forget` once you have used the Arc, or you could decrease the reference count by two, causing a double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);

unsafe impl<T> Send for OverlappedArcWrapper<T> {}

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}

fn get_ptr(&self) -> *const T {
self.0
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
@@ -116,7 +86,7 @@

#[derive(Debug)]
pub struct SockState {
iosb: Pin<Box<IoStatusBlock>>,
iosb: IoStatusBlock,
poll_info: AfdPollInfo,
afd: Arc<Afd>,

@@ -129,15 +99,15 @@
user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,

pinned: PhantomPinned,
}

impl SockState {
fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
Ok(SockState {
iosb: Pin::new(Box::new(IoStatusBlock::zeroed())),
iosb: IoStatusBlock::zeroed(),
poll_info: AfdPollInfo::zeroed(),
afd,
raw_socket,
@@ -146,8 +116,8 @@ impl SockState {
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
pinned: PhantomPinned,
})
}

@@ -162,7 +132,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
@@ -185,38 +155,37 @@
/* No poll operation is pending; start one. */
self.poll_info.exclusive = 0;
self.poll_info.number_of_handles = 1;
unsafe {
*self.poll_info.timeout.QuadPart_mut() = std::i64::MAX;
}
*unsafe { self.poll_info.timeout.QuadPart_mut() } = std::i64::MAX;
self.poll_info.handles[0].handle = self.base_socket as HANDLE;
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
// Increase the ref count as the memory will be used by the kernel.
let overlapped_ptr = into_overlapped(self_arc.clone());

let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
.poll(&mut self.poll_info, &mut *self.iosb, overlapped_ptr)
};
if let Err(e) = result {
let code = e.raw_os_error().unwrap();
if code == ERROR_IO_PENDING as i32 {
/* Overlapped poll operation in progress; this is expected. */
} else if code == ERROR_INVALID_HANDLE as i32 {
/* Socket closed; it'll be dropped. */
self.mark_delete();
return Ok(());
} else {
return Err(e);
// Since the operation failed it means the kernel won't be
// using the memory any more.
drop(from_overlapped(overlapped_ptr as *mut _));
if code == ERROR_INVALID_HANDLE as i32 {
/* Socket closed; it'll be dropped. */
self.mark_delete();
return Ok(());
} else {
return Err(e);
}
}
}

if self.self_wrapped.is_some() {
// This shouldn't happen. We cannot deallocate an overlapped that is still pending before feed_event, so declare this unreachable.
unreachable!();
}
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
} else {
unreachable!();
@@ -230,7 +199,7 @@
_ => unreachable!(),
};
unsafe {
self.afd.cancel((*self.iosb).as_mut_ptr())?;
self.afd.cancel(&mut *self.iosb)?;
}
self.poll_status = SockPollStatus::Cancelled;
self.pending_evts = 0;
@@ -239,24 +208,17 @@

// This is the function called from the overlapped pointer, which is used as an Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

let mut afd_events = 0;
// We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer to the IO_STATUS_BLOCK.
unsafe {
let iosb = &*(*self.iosb).as_ptr();
if self.delete_pending {
return None;
} else if iosb.u.Status == STATUS_CANCELLED {
} else if self.iosb.u.Status == STATUS_CANCELLED {
/* The poll request was cancelled by CancelIoEx. */
} else if !NT_SUCCESS(iosb.u.Status) {
} else if !NT_SUCCESS(self.iosb.u.Status) {
/* The overlapped request itself failed in an unexpected way. */
afd_events = afd::POLL_CONNECT_FAIL;
} else if self.poll_info.number_of_handles < 1 {
@@ -310,6 +272,21 @@ impl SockState {
}
}

/// Converts the pointer to a `SockState` into a raw pointer.
/// To revert see `from_overlapped`.
fn into_overlapped(sock_state: Pin<Arc<Mutex<SockState>>>) -> PVOID {
let overlapped_ptr: *const Mutex<SockState> =
unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
Review comment (Contributor):
This breaks CI tests on Windows. If this gets merged as it is, it will block testing on Windows side because CI will always fail here.

error[E0658]: use of unstable library feature 'pin_into_inner'
--> src\sys\windows\selector.rs:279:32
|
279 | unsafe { Arc::into_raw(Pin::into_inner_unchecked(sock_state)) };
| ^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: for more information, see rust-lang/rust#60245

error: aborting due to previous error

For more information about this error, try rustc --explain E0658.
error: Could not compile mio.

To learn more, run the command again with --verbose.

overlapped_ptr as *mut _
}

/// Convert a raw overlapped pointer into a reference to `SockState`.
/// Reverts `into_overlapped`.
fn from_overlapped(ptr: *mut OVERLAPPED) -> Pin<Arc<Mutex<SockState>>> {
let sock_ptr: *const Mutex<SockState> = ptr as *const _;
unsafe { Pin::new_unchecked(Arc::from_raw(sock_ptr)) }
}
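
Taken together, into_overlapped and from_overlapped replace the old OverlappedArcWrapper: each in-flight poll operation simply owns one strong Arc reference, handed to the kernel by into_overlapped and reclaimed by from_overlapped when the completion arrives (or dropped immediately on the failure path in update). Below is a minimal sketch of that round trip, using a dummy State alias instead of the real SockState and omitting the Pin layer to keep the focus on the reference count.

use std::sync::{Arc, Mutex};

type State = Mutex<u32>;

fn into_overlapped(state: Arc<State>) -> *const State {
    // Consumes one strong reference; the count stays raised until
    // from_overlapped is called on the returned pointer.
    Arc::into_raw(state)
}

fn from_overlapped(ptr: *const State) -> Arc<State> {
    // Safety: ptr must have been produced by into_overlapped exactly once.
    unsafe { Arc::from_raw(ptr) }
}

fn main() {
    let state = Arc::new(Mutex::new(0u32));

    // "Start" an operation: lend one reference to the (imaginary) kernel.
    let ptr = into_overlapped(Arc::clone(&state));
    assert_eq!(Arc::strong_count(&state), 2);

    // "Complete" the operation: reclaim the reference and let it drop.
    drop(from_overlapped(ptr));
    assert_eq!(Arc::strong_count(&state), 1);
}

As long as every into_overlapped is matched by exactly one from_overlapped, the count can neither leak nor be decremented twice, which is the invariant the removed forget-based bookkeeping had to maintain by hand.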

impl Drop for SockState {
fn drop(&mut self) {
self.mark_delete();
@@ -406,14 +383,49 @@
#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}

// We have ensured thread safety by introducing locking manually.
unsafe impl Sync for SelectorInner {}

impl Drop for SelectorInner {
fn drop(&mut self) {
loop {
let events_num: usize;
let mut statuses: [CompletionStatus; 1024] = [CompletionStatus::zero(); 1024];

let result = self
.cp
.get_many(&mut statuses, Some(std::time::Duration::from_millis(0)));
match result {
Ok(iocp_events) => {
events_num = iocp_events.iter().len();
for iocp_event in iocp_events.iter() {
if !iocp_event.overlapped().is_null() {
// drain the sock state to release the memory held by the Arc reference
let _sock_state = from_overlapped(iocp_event.overlapped());
}
}
}

Err(_) => {
break;
}
}

if events_num < 1024 {
// continue looping until all completion statuses have been drained
break;
}
}

self.afd_group.release_unused_afd();
}
}
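
The Drop impl above drains the completion port in batches so that any overlapped pointers still queued get turned back into Arcs and released. Below is a minimal sketch of that batch-drain pattern; drain_all and dequeue are made-up names, not miow's API.

const BUF: usize = 1024;

fn drain_all<F: FnMut(&mut [Option<u64>]) -> usize>(mut dequeue: F) {
    let mut statuses = [None; BUF];
    loop {
        let n = dequeue(&mut statuses);
        for _status in statuses.iter().take(n).flatten() {
            // In the real code this is where from_overlapped(ptr) reclaims
            // the Arc behind each completion entry.
        }
        if n < BUF {
            // A short batch means the port is empty; stop looping.
            break;
        }
    }
}

fn main() {
    // Fake source yielding 2500 completions, to exercise the batching.
    let mut remaining: usize = 2500;
    drain_all(|buf| {
        let n = remaining.min(buf.len());
        for slot in buf.iter_mut().take(n) {
            *slot = Some(42);
        }
        remaining -= n;
        n
    });
    assert_eq!(remaining, 0);
}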

impl SelectorInner {
pub fn new() -> io::Result<SelectorInner> {
CompletionPort::new(0).map(|cp| {
@@ -602,8 +614,9 @@ impl SelectorInner {
n += 1;
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let mut sock_guard = sock_arc.lock().unwrap();

let sock_state = from_overlapped(iocp_event.overlapped());
let mut sock_guard = sock_state.lock().unwrap();
match sock_guard.feed_event() {
Some(e) => {
events.push(e);
Expand All @@ -612,7 +625,7 @@ impl SelectorInner {
}
n += 1;
if !sock_guard.is_pending_deletion() {
update_queue.push_back(sock_arc.clone());
update_queue.push_back(sock_state.clone());
}
}
self.afd_group.release_unused_afd();
@@ -622,9 +635,9 @@
fn _alloc_sock_for_rawsocket(
&self,
raw_socket: RawSocket,
) -> io::Result<Arc<Mutex<SockState>>> {
) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
let afd = self.afd_group.acquire()?;
Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?)))
Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?)))
}
}
