
Commit

Changes to fix leaked sock states, issue tokio-rs#1146, and a couple of improvements.

All registered sock_states are saved in a slab data structure. Sock_states are removed from the slab during deregister; a reregister expects to find the sock_state already saved in the slab. During Selector drop, all sock_states in the slab are cancelled and dropped. The slab saves sock_states as (key, sock_state) pairs. The key is also saved in the sock_state and used as overlapped data; later, when the completion event is received, the key is used to retrieve the sock_state. The slab will never fill up with unusable sock_states that have been marked for deletion, because once a sock_state is marked for deletion it is also removed from the slab. The specific cases that trigger the removal are:
1. a sock_state for which completion returned an error, e.g. ERROR_INVALID_HANDLE;
2. a sock_state for which an afd::POLL_LOCAL_CLOSE event was retrieved.

Signed-off-by: Daniel Tacalau <dst4096@gmail.com>
dtacalau committed Nov 21, 2019
1 parent 952932a commit 09cff22
Showing 3 changed files with 136 additions and 72 deletions.
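A minimal sketch of the lifecycle described in the commit message, assuming the slab 0.4 API; the SockState here is a hypothetical stub (the real one also carries the AFD handle, poll info and event masks), and the OVERLAPPED pointer is modeled as a plain usize:

use slab::Slab;
use std::sync::{Arc, Mutex};

// Stub with only the fields relevant to the scheme.
struct SockState {
    id: usize,
    delete_pending: bool,
}

fn main() {
    let mut all: Slab<Arc<Mutex<SockState>>> = Slab::with_capacity(1024);

    // Register: reserve a slab entry and store its key inside the state.
    let entry = all.vacant_entry();
    let key = entry.key();
    entry.insert(Arc::new(Mutex::new(SockState {
        id: key,
        delete_pending: false,
    })));

    // Keys start at 0, but overlapped data of 0 would be a NULL pointer,
    // so key + 1 is what gets handed to the kernel as the per-poll token.
    let overlapped = key + 1;

    // Completion side: decode the token and look the state up again.
    let id = overlapped - 1;
    if all.contains(id) {
        let sock_state = all[id].clone();
        let state = sock_state.lock().unwrap();
        // ... feed_event() would run here ...
        if state.delete_pending {
            // Marked-for-deletion states are removed immediately, so the
            // slab never fills up with unusable entries.
            all.remove(state.id);
        }
    }
    // else: a completion for an already-removed (cancelled) state is
    // silently dropped.
}

Deregister performs the same removal keyed by the stored id, and Selector drop drains the slab, calling mark_delete on every remaining state; both are visible in the selector.rs diff below.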
1 change: 1 addition & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ publish = false

[dependencies]
log = "0.4.8"
slab = "0.4.2"

[target.'cfg(unix)'.dependencies]
libc = "0.2.62"
202 changes: 131 additions & 71 deletions src/sys/windows/selector.rs
@@ -1,3 +1,5 @@
use slab::Slab;

use super::afd::{self, Afd, AfdPollInfo};
use super::io_status_block::IoStatusBlock;
use super::Event;
@@ -17,6 +19,7 @@ use std::sync::atomic::AtomicUsize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::usize;
use std::{io, ptr};
use winapi::shared::ntdef::NT_SUCCESS;
use winapi::shared::ntdef::{HANDLE, PVOID};
@@ -75,38 +78,6 @@ impl AfdGroup {
}
}

/// This is the deallocation wrapper for overlapped pointer.
/// In case of error or status changing before the overlapped pointer is actually used(or not even being used),
/// this wrapper will decrease the reference count of Arc if being dropped.
/// Remember call `forget` if you have used the Arc, or you could decrease the reference count by two causing double free.
#[derive(Debug)]
struct OverlappedArcWrapper<T>(*const T);

unsafe impl<T> Send for OverlappedArcWrapper<T> {}

impl<T> OverlappedArcWrapper<T> {
fn new(arc: &Arc<T>) -> OverlappedArcWrapper<T> {
OverlappedArcWrapper(Arc::into_raw(arc.clone()))
}

fn forget(&mut self) {
self.0 = 0 as *const T;
}

fn get_ptr(&self) -> *const T {
self.0
}
}

impl<T> Drop for OverlappedArcWrapper<T> {
fn drop(&mut self) {
if self.0 as usize == 0 {
return;
}
drop(unsafe { Arc::from_raw(self.0) });
}
}

#[derive(Debug)]
enum SockPollStatus {
Idle,
@@ -123,13 +94,13 @@ pub struct SockState {
raw_socket: RawSocket,
base_socket: RawSocket,

id: usize,
user_evts: u32,
pending_evts: u32,

user_data: u64,

poll_status: SockPollStatus,
self_wrapped: Option<OverlappedArcWrapper<Mutex<SockState>>>,

delete_pending: bool,
}
@@ -142,15 +113,29 @@ impl SockState {
afd,
raw_socket,
base_socket: get_base_socket(raw_socket)?,
// usize::MAX is not a valid id; set_id must assign a valid id before this field is used
id: usize::MAX,
user_evts: 0,
pending_evts: 0,
user_data: 0,
poll_status: SockPollStatus::Idle,
self_wrapped: None,
delete_pending: false,
})
}

/// Returns true if the id was set successfully, false otherwise.
///
/// Note: it is an error to set the id more than once.
fn set_id(&mut self, id: usize) -> bool {
if self.id == usize::MAX {
self.id = id;
true
} else {
false
}
}

/// True if need to be added on update queue, false otherwise.
fn set_event(&mut self, ev: Event) -> bool {
/* afd::POLL_CONNECT_FAIL and afd::POLL_ABORT are always reported, even when not requested by the caller. */
@@ -162,7 +147,7 @@ impl SockState {
(events & !self.pending_evts) != 0
}

fn update(&mut self, self_arc: &Arc<Mutex<SockState>>) -> io::Result<()> {
fn update(&mut self) -> io::Result<()> {
assert!(!self.delete_pending);

if let SockPollStatus::Pending = self.poll_status {
@@ -192,8 +177,10 @@ impl SockState {
self.poll_info.handles[0].status = 0;
self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

let wrapped_overlapped = OverlappedArcWrapper::new(self_arc);
let overlapped = wrapped_overlapped.get_ptr() as *const _ as PVOID;
// Use the sock_state's unique id as the overlapped data; the id is later used to retrieve the sock_state.
// Note: the id is a slab key, which starts at 0, but overlapped data of 0 would be a NULL pointer,
// so the id is increased by 1 when sent and the receiving side decreases it by 1 before use.
let overlapped = (self.id + 1) as PVOID;
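// For example, slab key 0 goes out as overlapped value 1, and the completion side
// computes (iocp_event.overlapped() as usize) - 1 to recover the key before indexing the slab.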
let result = unsafe {
self.afd
.poll(&mut self.poll_info, (*self.iosb).as_mut_ptr(), overlapped)
@@ -211,12 +198,7 @@ impl SockState {
}
}

if self.self_wrapped.is_some() {
// This shouldn't be happening. We cannot deallocate already pending overlapped before feed_event so we need to stand out here to declare unreachable.
unreachable!();
}
self.poll_status = SockPollStatus::Pending;
self.self_wrapped = Some(wrapped_overlapped);
self.pending_evts = self.user_evts;
} else {
unreachable!();
@@ -239,12 +221,6 @@ impl SockState {

// This is the function called from the overlapped using an Arc<Mutex<SockState>>. Watch out for reference counting.
fn feed_event(&mut self) -> Option<Event> {
if self.self_wrapped.is_some() {
// Forget our arced-self first. We will decrease the reference count by two if we don't do this on overlapped.
self.self_wrapped.as_mut().unwrap().forget();
self.self_wrapped = None;
}

self.poll_status = SockPollStatus::Idle;
self.pending_evts = 0;

@@ -301,11 +277,11 @@ impl SockState {

pub fn mark_delete(&mut self) {
if !self.delete_pending {
self.delete_pending = true;

if let SockPollStatus::Pending = self.poll_status {
drop(self.cancel());
}

self.delete_pending = true;
}
}
}
@@ -403,26 +379,55 @@ impl Selector {
}
}

#[derive(Debug)]
pub struct SockStates {
/// contains sock_states which need to be updated by calling afd.poll
update_queue: VecDeque<Arc<Mutex<SockState>>>,
/// contains all sock_states which have been registered so far
all: Slab<Arc<Mutex<SockState>>>,
}

#[derive(Debug)]
pub struct SelectorInner {
cp: Arc<CompletionPort>,
update_queue: Mutex<VecDeque<Arc<Mutex<SockState>>>>,
sock_states: Mutex<SockStates>,
afd_group: AfdGroup,
is_polling: AtomicBool,
}

// Thread safety is ensured by manual locking.
unsafe impl Sync for SelectorInner {}

impl Drop for SelectorInner {
fn drop(&mut self) {
let all_sock_states = &mut self.sock_states.lock().unwrap().all;
for sock_state in all_sock_states.drain() {
let sock_state_internal = &mut sock_state.lock().unwrap();
sock_state_internal.mark_delete();
}

self.afd_group.release_unused_afd();
}
}

enum SocketOps {
SocketRegister,
SocketReregister,
SocketDeregister,
}

impl SelectorInner {
pub fn new() -> io::Result<SelectorInner> {
CompletionPort::new(0).map(|cp| {
let cp = Arc::new(cp);
let cp_afd = Arc::clone(&cp);

let sock_states = SockStates {
update_queue: VecDeque::new(),
all: Slab::with_capacity(1024),
};
SelectorInner {
cp,
update_queue: Mutex::new(VecDeque::new()),
sock_states: Mutex::new(sock_states),
afd_group: AfdGroup::new(cp_afd),
is_polling: AtomicBool::new(false),
}
@@ -517,7 +522,7 @@ impl SelectorInner {
}
socket.set_sock_state(Some(sock));
unsafe {
self.add_socket_to_update_queue(socket);
self.update_sock_states(socket, SocketOps::SocketRegister);
self.update_sockets_events_if_polling()?;
}

@@ -545,7 +550,7 @@
sock.lock().unwrap().set_event(event);
}
unsafe {
self.add_socket_to_update_queue(socket);
self.update_sock_states(socket, SocketOps::SocketReregister);
self.update_sockets_events_if_polling()?;
}

@@ -556,23 +561,38 @@
if socket.get_sock_state().is_none() {
return Err(io::Error::from(io::ErrorKind::NotFound));
}
unsafe {
self.update_sock_states(socket, SocketOps::SocketDeregister);
}
socket.set_sock_state(None);
self.afd_group.release_unused_afd();
Ok(())
}

unsafe fn update_sockets_events(&self) -> io::Result<()> {
let mut update_queue = self.update_queue.lock().unwrap();
let sock_states = &mut self.sock_states.lock().unwrap();

loop {
let sock = match update_queue.pop_front() {
let sock = match sock_states.update_queue.pop_front() {
Some(sock) => sock,
None => break,
};

let mut sock_internal = sock.lock().unwrap();
if !sock_internal.is_pending_deletion() {
sock_internal.update(&sock).unwrap();
sock_internal.update().unwrap();
}

// If this socket was marked for deletion during the update (e.g. because of an error),
// remove it now. Check that the slab still contains it to avoid a double remove;
// sockets cancelled during Selector drop may have been removed already.
if sock_internal.is_pending_deletion() && sock_states.all.contains(sock_internal.id) {
sock_states.all.remove(sock_internal.id);
}
}

self.afd_group.release_unused_afd();
Ok(())
}
@@ -602,10 +622,34 @@
}
}

unsafe fn add_socket_to_update_queue<S: SocketState>(&self, socket: &S) {
unsafe fn update_sock_states<S: SocketState>(&self, socket: &S, sock_op: SocketOps) {
let sock_state = socket.get_sock_state().unwrap();
let mut update_queue = self.update_queue.lock().unwrap();
update_queue.push_back(sock_state);
let sock_states = &mut self.sock_states.lock().unwrap();

match sock_op {
SocketOps::SocketRegister => {
sock_states.update_queue.push_back(sock_state.clone());

let entry = sock_states.all.vacant_entry();
let key = entry.key();
{
let mut sock_state_internal = sock_state.lock().unwrap();
assert!(sock_state_internal.set_id(key)); // this should always succeed; set_id is called only once per socket
};
entry.insert(sock_state);
}

SocketOps::SocketReregister => {
assert!(sock_states.all.contains(sock_state.lock().unwrap().id));
sock_states.update_queue.push_back(sock_state);
}

SocketOps::SocketDeregister => {
let sock_state_internal = sock_state.lock().unwrap();
assert!(sock_states.all.contains(sock_state_internal.id));
sock_states.all.remove(sock_state_internal.id);
}
}
}

// It returns processed count of iocp_events rather than the events itself.
@@ -614,33 +658,49 @@
events: &mut Vec<Event>,
iocp_events: &[CompletionStatus],
) -> usize {
let mut n = 0;
let mut update_queue = self.update_queue.lock().unwrap();
let mut events_num = 0;
let sock_states = &mut self.sock_states.lock().unwrap();
for iocp_event in iocp_events.iter() {
if iocp_event.overlapped().is_null() {
// `Waker` event, we'll add a readable event to match the other platforms.
events.push(Event {
flags: afd::POLL_RECEIVE,
data: iocp_event.token() as u64,
});
n += 1;
events_num += 1;
continue;
}

// The overlapped data holds the sock_state's unique id, used to retrieve the sock_state.
// Note: the id is a slab key, which starts at 0; the sending side increased it by 1
// before using it as overlapped data, so it is decreased by 1 here before use.
let id = (iocp_event.overlapped() as usize) - 1;
if !sock_states.all.contains(id) {
// No sock_state was found for this id; this is probably an event for a cancelled
// sock_state that has already been removed, so silently drop it.
continue;
}
let sock_arc = Arc::from_raw(iocp_event.overlapped() as *const Mutex<SockState>);
let mut sock_guard = sock_arc.lock().unwrap();
match sock_guard.feed_event() {

let sock_state = sock_states.all[id].clone();
let sock_state_internal = &mut sock_state.lock().unwrap();
match sock_state_internal.feed_event() {
Some(e) => {
events.push(e);
events_num += 1;
}
None => {}
}
n += 1;
if !sock_guard.is_pending_deletion() {
update_queue.push_back(sock_arc.clone());

if !sock_state_internal.is_pending_deletion() {
sock_states.update_queue.push_back(sock_state.clone());
} else {
// If the sock_state got a close event, it was marked for deletion, so just remove it.
assert!(sock_states.all.contains(sock_state_internal.id));
sock_states.all.remove(sock_state_internal.id);
}
}
self.afd_group.release_unused_afd();
n
events_num
}

fn _alloc_sock_for_rawsocket(
5 changes: 4 additions & 1 deletion tests/tcp_stream.rs
@@ -13,9 +13,12 @@ use mio::{Interests, Token};
#[macro_use]
mod util;

#[cfg(not(target_os = "windows"))]
use util::init;

use util::{
any_local_address, any_local_ipv6_address, assert_send, assert_sync, assert_would_block,
expect_events, expect_no_events, init, init_with_poll, ExpectEvent,
expect_events, expect_no_events, init_with_poll, ExpectEvent,
};

const DATA1: &[u8] = b"Hello world!";
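For context, a rough illustration of the register/deregister churn that the slab bookkeeping keeps bounded. This is a hypothetical sketch assuming the registry-based API of the mio 0.7 development branch (Poll, Token and Interests as used in the tests above); exact signatures may differ at this commit:

use mio::net::TcpStream;
use mio::{Interests, Poll, Token};

fn main() -> std::io::Result<()> {
    let poll = Poll::new()?;
    let addr = "127.0.0.1:8080".parse().unwrap();

    // Each register inserts a sock_state into the selector's slab and each
    // deregister removes it, so repeated cycles no longer leak sock states.
    for _ in 0..10_000 {
        let stream = TcpStream::connect(addr)?;
        poll.registry().register(&stream, Token(0), Interests::READABLE)?;
        poll.registry().deregister(&stream)?;
    }
    Ok(())
}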
