Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix windows SockState reference counting + misc clean-ups #1149

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/sys/windows/io_status_block.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
use ntapi::ntioapi::{IO_STATUS_BLOCK_u, IO_STATUS_BLOCK};
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};

pub struct IoStatusBlock(UnsafeCell<IO_STATUS_BLOCK>);

// There is a pointer field in `IO_STATUS_BLOCK_u`, which we don't use that. Thus it is safe to implement Send here.
unsafe impl Send for IoStatusBlock {}
pub struct IoStatusBlock(IO_STATUS_BLOCK);

impl IoStatusBlock {
pub fn zeroed() -> IoStatusBlock {
let iosb = IO_STATUS_BLOCK {
pub fn zeroed() -> Self {
Self(IO_STATUS_BLOCK {
u: IO_STATUS_BLOCK_u { Status: 0 },
Information: 0,
};
IoStatusBlock(UnsafeCell::new(iosb))
})
}
}

pub fn as_ptr(&self) -> *const IO_STATUS_BLOCK {
self.0.get()
unsafe impl Send for IoStatusBlock {}

impl Deref for IoStatusBlock {
type Target = IO_STATUS_BLOCK;
fn deref(&self) -> &Self::Target {
&self.0
}
}

pub fn as_mut_ptr(&self) -> *mut IO_STATUS_BLOCK {
self.0.get()
impl DerefMut for IoStatusBlock {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/sys/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::mem::size_of_val;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::{Arc, Mutex, Once};
use winapi::ctypes::c_int;
use winapi::shared::ws2def::SOCKADDR;
Expand Down Expand Up @@ -51,8 +52,8 @@ pub use udp::UdpSocket;
pub use waker::Waker;

pub trait SocketState {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>>;
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>);
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>>;
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>);
}

use crate::{Interests, Token};
Expand All @@ -61,7 +62,7 @@ struct InternalState {
selector: Arc<SelectorInner>,
token: Token,
interests: Interests,
sock_state: Option<Arc<Mutex<SockState>>>,
sock_state: Option<Pin<Arc<Mutex<SockState>>>>,
}

impl InternalState {
Expand Down
83 changes: 21 additions & 62 deletions src/sys/windows/selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::{Interests, Token};
use miow::iocp::{CompletionPort, CompletionStatus};
use miow::Overlapped;
use std::collections::VecDeque;
use std::mem::size_of;
use std::marker::PhantomPinned;
use std::mem::{forget, size_of, transmute_copy};
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
use std::ptr::null_mut;
Expand Down Expand Up @@ -75,38 +76,6 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for overlapped pointer.
/// In case of error or status changing before the overlapped pointer is actually used(or not even being used),
/// this wrapper will decrease the reference count of Arc if being dropped.
/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);

unsafe impl<T> Send for OverlappedArcWrapper<T> {}

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}

fn get_ptr(&self) -> *const T {
self.0
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
Expand All @@ -116,7 +85,7 @@ enum SockPollStatus {

#[derive(Debug)]
pub struct SockState {
iosb: Pin<Box<IoStatusBlock>>,
iosb: IoStatusBlock,
poll_info: AfdPollInfo,
afd: Arc<Afd>,

Expand All @@ -129,15 +98,15 @@ pub struct SockState {
user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,

pinned: PhantomPinned,
}

impl SockState {
fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> io::Result<SockState> {
Ok(SockState {
iosb: Pin::new(Box::new(IoStatusBlock::zeroed())),
iosb: IoStatusBlock::zeroed(),
poll_info: AfdPollInfo::zeroed(),
afd,
raw_socket,
Expand All @@ -146,8 +115,8 @@ impl SockState {
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
pinned: PhantomPinned,
})
}

Expand All @@ -162,7 +131,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
Expand Down Expand Up @@ -192,11 +161,12 @@ impl SockState {
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
self.afd.poll(
&mut self.poll_info,
&mut *self.iosb,
transmute_copy(self_arc),
)
};
if let Err(e) = result {
let code = e.raw_os_error().unwrap();
Expand All @@ -211,13 +181,9 @@ impl SockState {
}
}

if self.self_wrapped.is_some() {
// This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable.
unreachable!();
}
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
forget(self_arc.clone());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks strange, why call forget on the clone and not on the self_arc directly?

Copy link
Contributor Author

@piscisaureus piscisaureus Nov 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self_arc is a reference here. Dropping a reference does is no-op, therefore ’forget’ing a reference instead of dropping doesn't make any difference.

} else {
unreachable!();
}
Expand All @@ -230,7 +196,7 @@ impl SockState {
_ => unreachable!(),
};
unsafe {
self.afd.cancel((*self.iosb).as_mut_ptr())?;
self.afd.cancel(&mut *self.iosb)?;
}
self.poll_status = SockPollStatus::Cancelled;
self.pending_evts = 0;
Expand All @@ -239,24 +205,17 @@ impl SockState {

// This is the function called from the overlapped using as Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

let mut afd_events = 0;
// We use the status info in IO_STATUS_BLOCK to determine the socket poll status. It is unsafe to use a pointer of IO_STATUS_BLOCK.
unsafe {
let iosb = &*(*self.iosb).as_ptr();
if self.delete_pending {
return None;
} else if iosb.u.Status == STATUS_CANCELLED {
} else if self.iosb.u.Status == STATUS_CANCELLED {
/* The poll request was cancelled by CancelIoEx. */
} else if !NT_SUCCESS(iosb.u.Status) {
} else if !NT_SUCCESS(self.iosb.u.Status) {
/* The overlapped request itself failed in an unexpected way. */
afd_events = afd::POLL_CONNECT_FAIL;
} else if self.poll_info.number_of_handles < 1 {
Expand Down Expand Up @@ -406,7 +365,7 @@ impl Selector {
#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
update_queue: Mutex<VecDeque<Pin<Arc<Mutex<SockState>>>>>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}
Expand Down Expand Up @@ -626,7 +585,7 @@ impl SelectorInner {
n += 1;
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let sock_arc: Pin<Arc<Mutex<SockState>>> = transmute_copy(&iocp_event.overlapped());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a perfect copy without increasing the ref count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. It's basically reinterpret_cast + copy.

let mut sock_guard = sock_arc.lock().unwrap();
match sock_guard.feed_event() {
Some(e) => {
Expand All @@ -646,9 +605,9 @@ impl SelectorInner {
fn _alloc_sock_for_rawsocket(
&self,
raw_socket: RawSocket,
) -> io::Result<Arc<Mutex<SockState>>> {
) -> io::Result<Pin<Arc<Mutex<SockState>>>> {
let afd = self.afd_group.acquire()?;
Ok(Arc::new(Mutex::new(SockState::new(raw_socket, afd)?)))
Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?)))
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/sys/windows/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{self, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM};

Expand Down Expand Up @@ -127,7 +128,7 @@ impl TcpStream {
}

impl super::SocketState for TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -137,7 +138,7 @@ impl super::SocketState for TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand All @@ -160,7 +161,7 @@ impl super::SocketState for TcpStream {
}

impl<'a> super::SocketState for &'a TcpStream {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl<'a> super::SocketState for &'a TcpStream {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down Expand Up @@ -405,7 +406,7 @@ impl TcpListener {
}

impl super::SocketState for TcpListener {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -415,7 +416,7 @@ impl super::SocketState for TcpListener {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down
5 changes: 3 additions & 2 deletions src/sys/windows/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{event, poll, Interests, Registry, Token};
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64.
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::{fmt, io};
use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM};
Expand Down Expand Up @@ -160,7 +161,7 @@ impl UdpSocket {
}

impl super::SocketState for UdpSocket {
fn get_sock_state(&self) -> Option<Arc<Mutex<SockState>>> {
fn get_sock_state(&self) -> Option<Pin<Arc<Mutex<SockState>>>> {
let internal = self.internal.lock().unwrap();
match &*internal {
Some(internal) => match &internal.sock_state {
Expand All @@ -170,7 +171,7 @@ impl super::SocketState for UdpSocket {
None => None,
}
}
fn set_sock_state(&self, sock_state: Option<Arc<Mutex<SockState>>>) {
fn set_sock_state(&self, sock_state: Option<Pin<Arc<Mutex<SockState>>>>) {
let mut internal = self.internal.lock().unwrap();
match &mut *internal {
Some(internal) => {
Expand Down