From 60036559342045405b79bf3fc2e00367bef52d73 Mon Sep 17 00:00:00 2001 From: Bert Belder Date: Thu, 14 Nov 2019 16:57:10 -0800 Subject: [PATCH] WIP --- src/sys/windows/mod.rs | 9 +- src/sys/windows/selector.rs | 186 ++++++++++++++++++++++++------------ src/sys/windows/tcp.rs | 15 ++- src/sys/windows/udp.rs | 7 +- 4 files changed, 137 insertions(+), 80 deletions(-) diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 48ca1430aa..b55fdb149d 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,8 +1,7 @@ use std::io; use std::mem::size_of_val; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::pin::Pin; -use std::sync::{Arc, Mutex, Once}; +use std::sync::{Arc, Once}; use winapi::ctypes::c_int; use winapi::shared::ws2def::SOCKADDR; use winapi::um::winsock2::{ @@ -52,8 +51,8 @@ pub use udp::UdpSocket; pub use waker::Waker; pub trait SocketState { - fn get_sock_state(&self) -> Option>>>; - fn set_sock_state(&self, sock_state: Option>>>); + fn get_sock_state(&self) -> Option; + fn set_sock_state(&self, sock_state: Option); } use crate::{Interests, Token}; @@ -62,7 +61,7 @@ struct InternalState { selector: Arc, token: Token, interests: Interests, - sock_state: Option>>>, + sock_state: Option, } impl InternalState { diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index 261995c6c8..bd86afae77 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -7,9 +7,12 @@ use crate::{Interests, Token}; use miow::iocp::{CompletionPort, CompletionStatus}; use miow::Overlapped; -use std::collections::{HashMap, VecDeque}; +use std::cmp::Eq; +use std::collections::{HashSet, VecDeque}; +use std::hash::{Hash, Hasher}; use std::marker::PhantomPinned; use std::mem::{forget, size_of, transmute_copy}; +use std::ops::{Deref, DerefMut}; use std::os::windows::io::{AsRawSocket, RawSocket}; use std::pin::Pin; use std::ptr::null_mut; @@ -82,8 +85,49 @@ enum SockPollStatus { Cancelled, } +#[derive(Debug, Clone)] +pub struct SockState(Pin>>); + +impl SockState { + fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { + Ok(Self(Arc::pin(Mutex::new(SockStateInner::new( + raw_socket, afd, + )?)))) + } +} + +impl Eq for SockState {} + +impl PartialEq for SockState { + fn eq(&self, other: &Self) -> bool { + let ptr1: *const Mutex<_> = &*self.0; + let ptr2: *const Mutex<_> = &*other.0; + ptr1 == ptr2 + } +} + +impl Hash for SockState { + fn hash(&self, hasher: &mut H) { + let ptr: *const Mutex<_> = &*self.0; + ptr.hash(hasher); + } +} + +impl Deref for SockState { + type Target = Pin>>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SockState { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + #[derive(Debug)] -pub struct SockState { +pub struct SockStateInner { iosb: IoStatusBlock, poll_info: AfdPollInfo, afd: Arc, @@ -102,9 +146,21 @@ pub struct SockState { pinned: PhantomPinned, } -impl SockState { - fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { - Ok(SockState { +#[cfg(debug_assertions)] +static SOCK_STATE_COUNT: AtomicUsize = AtomicUsize::new(0); + +impl Drop for SockStateInner { + fn drop(&mut self) { + if cfg!(debug_assertions) { + let count = SOCK_STATE_COUNT.fetch_sub(1, Ordering::SeqCst) - 1; + println!("- {}", count); + } + } +} + +impl SockStateInner { + fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { + let sock_state = Self { iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), afd, @@ -116,7 +172,12 @@ impl SockState { poll_status: SockPollStatus::Idle, delete_pending: false, pinned: PhantomPinned, - }) + }; + if cfg!(debug_assertions) { + let count = SOCK_STATE_COUNT.fetch_add(1, Ordering::SeqCst) + 1; + println!("+ {}", count); + } + Ok(sock_state) } /// True if need to be added on update queue, false otherwise. @@ -130,10 +191,14 @@ impl SockState { (events & !self.pending_evts) != 0 } - fn update(&mut self, self_arc: &Pin>>) -> io::Result<()> { - assert!(!self.delete_pending); - + fn update(&mut self, self_arc: &SockState) -> io::Result<()> { use SockPollStatus::*; + + if self.delete_pending { + assert_ne!(self.poll_status, Pending); + return Ok(()); + } + match self.poll_status { Pending if (self.user_evts & afd::KNOWN_EVENTS & !self.pending_evts) == 0 => { // All the events the user is interested in are already being monitored by @@ -256,6 +321,15 @@ impl SockState { self.delete_pending } + pub fn is_submitted(&self) -> bool { + use SockPollStatus::*; + match self.poll_status { + Idle => false, + Pending => true, + Cancelled => true, + } + } + pub fn mark_delete(&mut self) { if !self.delete_pending { if self.poll_status == SockPollStatus::Pending { @@ -266,12 +340,6 @@ impl SockState { } } -impl Drop for SockState { - fn drop(&mut self) { - self.mark_delete(); - } -} - /// Each Selector has a globally unique(ish) ID associated with it. This ID /// gets tracked by `TcpStream`, `TcpListener`, etc... when they are first /// registered with the `Selector`. If a type that is previously associated with @@ -362,8 +430,8 @@ impl Selector { #[derive(Debug)] pub struct SelectorInner { cp: Arc, - update_queue: Mutex>>>>, - delete_queue: Mutex>>>>, + update_queue: Mutex>, + pending_queue: Mutex>, afd_group: AfdGroup, is_polling: AtomicBool, } @@ -373,17 +441,22 @@ unsafe impl Sync for SelectorInner {} impl Drop for SelectorInner { fn drop(&mut self) { - let mut delete_queue = self.delete_queue.lock().unwrap(); - for (_, sock) in delete_queue.drain() { - let mut sock_internal = sock.lock().unwrap(); - sock_internal.mark_delete(); + for s in { + self.pending_queue + .lock() + .unwrap() + .iter() + .chain(self.update_queue.lock().unwrap().iter()) + } { + s.lock().unwrap().mark_delete() } - let mut events = Events::with_capacity(16); - let result = self.select(&mut events, Some(std::time::Duration::from_millis(0))); - match result { - Ok(_) => {} - Err(_) => {} + while !(self.pending_queue.lock().unwrap().is_empty() + && self.update_queue.lock().unwrap().is_empty()) + { + let mut events = Events::with_capacity(16); + self.select(&mut events, Some(std::time::Duration::from_millis(0))) + .unwrap(); } self.afd_group.release_unused_afd(); @@ -399,7 +472,7 @@ impl SelectorInner { SelectorInner { cp, update_queue: Mutex::new(VecDeque::new()), - delete_queue: Mutex::new(HashMap::new()), + pending_queue: Mutex::new(HashSet::new()), afd_group: AfdGroup::new(cp_afd), is_polling: AtomicBool::new(false), } @@ -450,7 +523,7 @@ impl SelectorInner { } } - pub fn select2( + fn select2( &self, statuses: &mut [CompletionStatus], events: &mut Vec, @@ -495,7 +568,6 @@ impl SelectorInner { socket.set_sock_state(Some(sock)); unsafe { self.add_socket_to_update_queue(socket); - self.add_socket_to_delete_queue(socket); self.update_sockets_events_if_polling()?; } @@ -531,27 +603,32 @@ impl SelectorInner { } pub fn deregister(&self, socket: &S) -> io::Result<()> { - if socket.get_sock_state().is_none() { - return Err(io::Error::from(io::ErrorKind::NotFound)); - } - unsafe { - self.remove_socket_from_delete_queue(socket); + match socket.get_sock_state() { + None => Err(io::Error::from(io::ErrorKind::NotFound)), + Some(sock_state) => { + sock_state.lock().unwrap().mark_delete(); + Ok(()) + } } - socket.set_sock_state(None); - self.afd_group.release_unused_afd(); - Ok(()) } unsafe fn update_sockets_events(&self) -> io::Result<()> { let mut update_queue = self.update_queue.lock().unwrap(); + let mut pending_queue = self.pending_queue.lock().unwrap(); loop { let sock = match update_queue.pop_front() { Some(sock) => sock, None => break, }; - let mut sock_internal = sock.lock().unwrap(); - if !sock_internal.is_pending_deletion() { - sock_internal.update(&sock).unwrap(); + // FIXME: this logic is terrible. + if { + let mut sock_internal = sock.lock().unwrap(); + let submitted_before = sock_internal.is_submitted(); + sock_internal.update(&sock)?; + let submitted_after = sock_internal.is_submitted(); + !submitted_before && submitted_after + } { + assert_eq!(pending_queue.insert(sock), true); } } self.afd_group.release_unused_afd(); @@ -589,20 +666,6 @@ impl SelectorInner { update_queue.push_back(sock_state); } - unsafe fn add_socket_to_delete_queue(&self, socket: &S) { - let sock_state = socket.get_sock_state().unwrap(); - let user_data = sock_state.lock().unwrap().user_data; - let mut delete_queue = self.delete_queue.lock().unwrap(); - delete_queue.insert(user_data, sock_state); - } - - unsafe fn remove_socket_from_delete_queue(&self, socket: &S) { - let sock_state = socket.get_sock_state().unwrap(); - let sock_internal = sock_state.lock().unwrap(); - let mut delete_queue = self.delete_queue.lock().unwrap(); - delete_queue.remove(&sock_internal.user_data); - } - // It returns processed count of iocp_events rather than the events itself. unsafe fn feed_events( &self, @@ -611,6 +674,7 @@ impl SelectorInner { ) -> usize { let mut n = 0; let mut update_queue = self.update_queue.lock().unwrap(); + let mut pending_queue = self.pending_queue.lock().unwrap(); for iocp_event in iocp_events.iter() { if iocp_event.overlapped().is_null() { // `Waker` event, we'll add a readable event to match the other platforms. @@ -621,12 +685,11 @@ impl SelectorInner { n += 1; continue; } - let sock_arc: Pin>> = transmute_copy(&iocp_event.overlapped()); + let sock_arc: SockState = transmute_copy(&iocp_event.overlapped()); + assert_eq!(pending_queue.remove(&sock_arc), true); let mut sock_guard = sock_arc.lock().unwrap(); match sock_guard.feed_event() { - Some(e) => { - events.push(e); - } + Some(e) => events.push(e), None => {} } n += 1; @@ -638,12 +701,9 @@ impl SelectorInner { n } - fn _alloc_sock_for_rawsocket( - &self, - raw_socket: RawSocket, - ) -> io::Result>>> { + fn _alloc_sock_for_rawsocket(&self, raw_socket: RawSocket) -> io::Result { let afd = self.afd_group.acquire()?; - Ok(Arc::pin(Mutex::new(SockState::new(raw_socket, afd)?))) + SockState::new(raw_socket, afd) } } diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 368de55649..f6e4fdd563 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -8,8 +8,7 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::{self, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. -use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use winapi::um::winsock2::{bind, closesocket, connect, listen, SOCKET_ERROR, SOCK_STREAM}; pub struct TcpStream { @@ -128,7 +127,7 @@ impl TcpStream { } impl super::SocketState for TcpStream { - fn get_sock_state(&self) -> Option>>> { + fn get_sock_state(&self) -> Option { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -138,7 +137,7 @@ impl super::SocketState for TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>>) { + fn set_sock_state(&self, sock_state: Option) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -161,7 +160,7 @@ impl super::SocketState for TcpStream { } impl<'a> super::SocketState for &'a TcpStream { - fn get_sock_state(&self) -> Option>>> { + fn get_sock_state(&self) -> Option { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -171,7 +170,7 @@ impl<'a> super::SocketState for &'a TcpStream { None => None, } } - fn set_sock_state(&self, sock_state: Option>>>) { + fn set_sock_state(&self, sock_state: Option) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { @@ -406,7 +405,7 @@ impl TcpListener { } impl super::SocketState for TcpListener { - fn get_sock_state(&self) -> Option>>> { + fn get_sock_state(&self) -> Option { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -416,7 +415,7 @@ impl super::SocketState for TcpListener { None => None, } } - fn set_sock_state(&self, sock_state: Option>>>) { + fn set_sock_state(&self, sock_state: Option) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => { diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index cf4db6fc29..862bf5799b 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -6,8 +6,7 @@ use crate::{event, poll, Interests, Registry, Token}; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. -use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use std::{fmt, io}; use winapi::um::winsock2::{bind, closesocket, SOCKET_ERROR, SOCK_DGRAM}; @@ -161,7 +160,7 @@ impl UdpSocket { } impl super::SocketState for UdpSocket { - fn get_sock_state(&self) -> Option>>> { + fn get_sock_state(&self) -> Option { let internal = self.internal.lock().unwrap(); match &*internal { Some(internal) => match &internal.sock_state { @@ -171,7 +170,7 @@ impl super::SocketState for UdpSocket { None => None, } } - fn set_sock_state(&self, sock_state: Option>>>) { + fn set_sock_state(&self, sock_state: Option) { let mut internal = self.internal.lock().unwrap(); match &mut *internal { Some(internal) => {