From 1200f91c1c6e04010c973e982cda39490eb4dd3f Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Fri, 29 Oct 2021 22:10:10 -0700 Subject: [PATCH 1/5] Dispatch connection events synchronously --- quinn/src/connection.rs | 60 +++++++----------------------------- quinn/src/endpoint.rs | 68 ++++++++++++++++------------------------- quinn/src/lib.rs | 10 ------ 3 files changed, 38 insertions(+), 100 deletions(-) diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 4fa5768c0..1c3a65740 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -23,7 +23,7 @@ use crate::{ mutex::Mutex, recv_stream::RecvStream, send_stream::{SendStream, WriteError}, - ConnectionEvent, EndpointEvent, VarInt, + EndpointEvent, VarInt, }; use proto::congestion::Controller; @@ -41,17 +41,15 @@ impl Connecting { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, - conn_events: mpsc::UnboundedReceiver, udp_state: Arc, runtime: Arc, - ) -> Connecting { + ) -> (Connecting, ConnectionRef) { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); let (on_connected_send, on_connected_recv) = oneshot::channel(); let conn = ConnectionRef::new( handle, conn, endpoint_events, - conn_events, on_handshake_data_send, on_connected_send, udp_state, @@ -59,12 +57,14 @@ impl Connecting { ); runtime.spawn(Box::pin(ConnectionDriver(conn.clone()))); - - Connecting { - conn: Some(conn), - connected: on_connected_recv, - handshake_data_ready: Some(on_handshake_data_recv), - } + ( + Connecting { + conn: Some(conn.clone()), + connected: on_connected_recv, + handshake_data_ready: Some(on_handshake_data_recv), + }, + conn, + ) } /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened security @@ -226,10 +226,6 @@ impl Future for ConnectionDriver { let span = debug_span!("drive", id = conn.handle.0); let _guard = span.enter(); - if let Err(e) = conn.process_conn_events(&self.0.shared, cx) { - conn.terminate(e, &self.0.shared); - return Poll::Ready(()); - } let mut keep_going = conn.drive_transmit(); // If a timer expires, there might be more to transmit. When we transmit something, we // might need to reset a timer. Hence, we must loop until neither happens. @@ -746,7 +742,6 @@ impl ConnectionRef { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, - conn_events: mpsc::UnboundedReceiver, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, udp_state: Arc, @@ -762,7 +757,6 @@ impl ConnectionRef { connected: false, timer: None, timer_deadline: None, - conn_events, endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), @@ -838,7 +832,6 @@ pub(crate) struct State { connected: bool, timer: Option>>, timer_deadline: Option, - conn_events: mpsc::UnboundedReceiver, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, @@ -890,37 +883,6 @@ impl State { } } - /// If this returns `Err`, the endpoint is dead, so the driver should exit immediately. 
- fn process_conn_events( - &mut self, - shared: &Shared, - cx: &mut Context, - ) -> Result<(), ConnectionError> { - loop { - match self.conn_events.poll_recv(cx) { - Poll::Ready(Some(ConnectionEvent::Ping)) => { - self.inner.ping(); - } - Poll::Ready(Some(ConnectionEvent::Proto(event))) => { - self.inner.handle_event(event); - } - Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => { - self.close(error_code, reason, shared); - } - Poll::Ready(None) => { - return Err(ConnectionError::TransportError(proto::TransportError { - code: proto::TransportErrorCode::INTERNAL_ERROR, - frame: None, - reason: "endpoint driver future was dropped".to_string(), - })); - } - Poll::Pending => { - return Ok(()); - } - } - } - } - fn forward_app_events(&mut self, shared: &Shared) { while let Some(event) = self.inner.poll() { use proto::Event::*; @@ -1073,7 +1035,7 @@ impl State { shared.closed.notify_waiters(); } - fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) { + pub fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) { self.inner.close(Instant::now(), error_code, reason); self.terminate(ConnectionError::LocallyClosed, shared); self.wake(); diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index c134231c7..94718f45a 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -23,8 +23,9 @@ use tokio::sync::{futures::Notified, mpsc, Notify}; use udp::{RecvMeta, UdpState, BATCH_SIZE}; use crate::{ - connection::Connecting, work_limiter::WorkLimiter, ConnectionEvent, EndpointConfig, - EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, + connection::{Connecting, ConnectionRef}, + work_limiter::WorkLimiter, + EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, }; /// A QUIC endpoint. @@ -210,9 +211,10 @@ impl Endpoint { inner.ipv6 = addr.is_ipv6(); // Generate some activity so peers notice the rebind - for sender in inner.connections.senders.values() { - // Ignoring errors from dropped connections - let _ = sender.send(ConnectionEvent::Ping); + for conn in inner.connections.refs.values() { + let mut state = conn.state.lock("ping"); + state.inner.ping(); + state.wake(); } Ok(()) @@ -244,12 +246,9 @@ impl Endpoint { let reason = Bytes::copy_from_slice(reason); let mut endpoint = self.inner.state.lock().unwrap(); endpoint.connections.close = Some((error_code, reason.clone())); - for sender in endpoint.connections.senders.values() { - // Ignoring errors from dropped connections - let _ = sender.send(ConnectionEvent::Close { - error_code, - reason: reason.clone(), - }); + for conn in endpoint.connections.refs.values() { + let mut state = conn.state.lock("close"); + state.close(error_code, reason.clone(), &conn.shared); } self.inner.shared.incoming.notify_waiters(); } @@ -333,9 +332,6 @@ impl Drop for EndpointDriver { let mut endpoint = self.0.state.lock().unwrap(); endpoint.driver_lost = true; self.0.shared.incoming.notify_waiters(); - // Drop all outgoing channels, signaling the termination of the endpoint to the associated - // connections. 
- endpoint.connections.senders.clear(); } } @@ -408,13 +404,10 @@ impl State { self.incoming.push_back(conn); } Some((handle, DatagramEvent::ConnectionEvent(event))) => { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&handle) - .unwrap() - .send(ConnectionEvent::Proto(event)); + let conn = self.connections.refs.get(&handle).unwrap(); + let mut state = conn.state.lock("handle_event"); + state.inner.handle_event(event); + state.wake(); } None => {} } @@ -493,19 +486,16 @@ impl State { Poll::Ready(Some((ch, event))) => match event { Proto(e) => { if e.is_drained() { - self.connections.senders.remove(&ch); + self.connections.refs.remove(&ch); if self.connections.is_empty() { shared.idle.notify_waiters(); } } if let Some(event) = self.inner.handle_event(ch, e) { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&ch) - .unwrap() - .send(ConnectionEvent::Proto(event)); + let conn = self.connections.refs.get(&ch).unwrap(); + let mut conn = conn.state.lock("handle_event"); + conn.inner.handle_event(event); + conn.wake(); } } Transmit(t) => self.outgoing.push_back(t), @@ -523,8 +513,7 @@ impl State { #[derive(Debug)] struct ConnectionSet { - /// Senders for communicating with the endpoint's connections - senders: FxHashMap>, + refs: FxHashMap, /// Stored to give out clones to new ConnectionInners sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, /// Set if the endpoint has been manually closed @@ -539,20 +528,17 @@ impl ConnectionSet { udp_state: Arc, runtime: Arc, ) -> Connecting { - let (send, recv) = mpsc::unbounded_channel(); + let (future, conn) = Connecting::new(handle, conn, self.sender.clone(), udp_state, runtime); if let Some((error_code, ref reason)) = self.close { - send.send(ConnectionEvent::Close { - error_code, - reason: reason.clone(), - }) - .unwrap(); + let mut state = conn.state.lock("close"); + state.close(error_code, reason.clone(), &conn.shared); } - self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, udp_state, runtime) + self.refs.insert(handle, conn); + future } fn is_empty(&self) -> bool { - self.senders.is_empty() + self.refs.is_empty() } } @@ -632,7 +618,7 @@ impl EndpointRef { incoming: VecDeque::new(), driver: None, connections: ConnectionSet { - senders: FxHashMap::default(), + refs: FxHashMap::default(), sender, close: None, }, diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 4aa366474..e3d420384 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -81,16 +81,6 @@ pub use crate::send_stream::{SendStream, StoppedError, WriteError}; #[cfg(test)] mod tests; -#[derive(Debug)] -enum ConnectionEvent { - Close { - error_code: VarInt, - reason: bytes::Bytes, - }, - Proto(proto::ConnectionEvent), - Ping, -} - #[derive(Debug)] enum EndpointEvent { Proto(proto::EndpointEvent), From f8a823434e869bdc0527d6119a7a153de7d18003 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Tue, 16 Nov 2021 18:13:00 -0800 Subject: [PATCH 2/5] Unify connection/endpoint drivers --- quinn/Cargo.toml | 1 + quinn/src/connection.rs | 178 ++++++++-------------------------------- quinn/src/endpoint.rs | 145 +++++++++++++++++++------------- quinn/src/lib.rs | 12 --- 4 files changed, 123 insertions(+), 213 deletions(-) diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index 877b72bfb..36fd00067 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -45,6 +45,7 @@ rustls = { 
version = "0.20.3", default-features = false, features = ["quic"], op thiserror = "1.0.21" tracing = "0.1.10" tokio = { version = "1.13.0", features = ["sync"] } +tokio-util = { version = "0.6.9", features = ["time"] } udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.3", default-features = false } webpki = { version = "0.22", default-features = false, optional = true } diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 1c3a65740..e36a983b7 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -1,5 +1,6 @@ use std::{ any::Any, + collections::VecDeque, fmt, future::Future, net::{IpAddr, SocketAddr}, @@ -9,13 +10,16 @@ use std::{ time::{Duration, Instant}, }; -use crate::runtime::{AsyncTimer, Runtime}; use bytes::Bytes; use pin_project_lite::pin_project; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; use rustc_hash::FxHashMap; use thiserror::Error; -use tokio::sync::{futures::Notified, mpsc, oneshot, Notify}; +use tokio::{ + sync::{futures::Notified, mpsc, oneshot, Notify}, + time::Instant as TokioInstant, +}; +use tokio_util::time::delay_queue; use tracing::debug_span; use udp::UdpState; @@ -23,7 +27,7 @@ use crate::{ mutex::Mutex, recv_stream::RecvStream, send_stream::{SendStream, WriteError}, - EndpointEvent, VarInt, + VarInt, }; use proto::congestion::Controller; @@ -38,25 +42,22 @@ pub struct Connecting { impl Connecting { pub(crate) fn new( + dirty: mpsc::UnboundedSender, handle: ConnectionHandle, conn: proto::Connection, - endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, udp_state: Arc, - runtime: Arc, ) -> (Connecting, ConnectionRef) { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); let (on_connected_send, on_connected_recv) = oneshot::channel(); let conn = ConnectionRef::new( handle, conn, - endpoint_events, + dirty, on_handshake_data_send, on_connected_send, udp_state, - runtime.clone(), ); - runtime.spawn(Box::pin(ConnectionDriver(conn.clone()))); ( Connecting { conn: Some(conn.clone()), @@ -202,53 +203,6 @@ impl Future for ZeroRttAccepted { } } -/// A future that drives protocol logic for a connection -/// -/// This future handles the protocol logic for a single connection, routing events from the -/// `Connection` API object to the `Endpoint` task and the related stream-related interfaces. -/// It also keeps track of outstanding timeouts for the `Connection`. -/// -/// If the connection encounters an error condition, this future will yield an error. It will -/// terminate (yielding `Ok(())`) if the connection was closed without error. Unlike other -/// connection-related futures, this waits for the draining period to complete to ensure that -/// packets still in flight from the peer are handled gracefully. -#[must_use = "connection drivers must be spawned for their connections to function"] -#[derive(Debug)] -struct ConnectionDriver(ConnectionRef); - -impl Future for ConnectionDriver { - type Output = (); - - #[allow(unused_mut)] // MSRV - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let conn = &mut *self.0.state.lock("poll"); - - let span = debug_span!("drive", id = conn.handle.0); - let _guard = span.enter(); - - let mut keep_going = conn.drive_transmit(); - // If a timer expires, there might be more to transmit. When we transmit something, we - // might need to reset a timer. Hence, we must loop until neither happens. 
- keep_going |= conn.drive_timer(cx); - conn.forward_endpoint_events(); - conn.forward_app_events(&self.0.shared); - - if !conn.inner.is_drained() { - if keep_going { - // If the connection hasn't processed all tasks, schedule it again - cx.waker().wake_by_ref(); - } else { - conn.driver = Some(cx.waker().clone()); - } - return Poll::Pending; - } - if conn.error.is_none() { - unreachable!("drained connections always have an error"); - } - Poll::Ready(()) - } -} - /// A QUIC connection. /// /// If all references to a connection (including every clone of the `Connection` handle, streams of @@ -741,23 +695,24 @@ impl ConnectionRef { fn new( handle: ConnectionHandle, conn: proto::Connection, - endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + dirty: mpsc::UnboundedSender, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, udp_state: Arc, - runtime: Arc, ) -> Self { + let _ = dirty.send(handle); Self(Arc::new(ConnectionInner { state: Mutex::new(State { inner: conn, - driver: None, handle, + span: debug_span!("connection", id = handle.0), + is_dirty: true, + dirty, on_handshake_data: Some(on_handshake_data), on_connected: Some(on_connected), connected: false, - timer: None, + timer_handle: None, timer_deadline: None, - endpoint_events, blocked_writers: FxHashMap::default(), blocked_readers: FxHashMap::default(), finishing: FxHashMap::default(), @@ -765,7 +720,6 @@ impl ConnectionRef { error: None, ref_count: 0, udp_state, - runtime, }), shared: Shared::default(), })) @@ -825,14 +779,18 @@ pub(crate) struct Shared { pub(crate) struct State { pub(crate) inner: proto::Connection, - driver: Option, handle: ConnectionHandle, + pub(crate) span: tracing::Span, + /// Whether `handle` has been sent to `dirty` since the last time this connection was driven by + /// the endpoint. Ensures `dirty`'s size stays bounded regardless of activity. + pub(crate) is_dirty: bool, + /// `handle` is sent here when `wake` is called, prompting the endpoint to drive the connection + dirty: mpsc::UnboundedSender, on_handshake_data: Option>, on_connected: Option>, connected: bool, - timer: Option>>, - timer_deadline: Option, - endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + pub(crate) timer_handle: Option, + pub(crate) timer_deadline: Option, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, pub(crate) finishing: FxHashMap>>, @@ -842,11 +800,10 @@ pub(crate) struct State { /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, udp_state: Arc, - runtime: Arc, } impl State { - fn drive_transmit(&mut self) -> bool { + pub(crate) fn drive_transmit(&mut self, out: &mut VecDeque) -> bool { let now = Instant::now(); let mut transmits = 0; @@ -857,10 +814,7 @@ impl State { None => 1, Some(s) => (t.contents.len() + s - 1) / s, // round up }; - // If the endpoint driver is gone, noop. - let _ = self - .endpoint_events - .send((self.handle, EndpointEvent::Transmit(t))); + out.push_back(t); if transmits >= MAX_TRANSMIT_DATAGRAMS { // TODO: What isn't ideal here yet is that if we don't poll all @@ -874,16 +828,7 @@ impl State { false } - fn forward_endpoint_events(&mut self) { - while let Some(event) = self.inner.poll_endpoint_events() { - // If the endpoint driver is gone, noop. 
- let _ = self - .endpoint_events - .send((self.handle, EndpointEvent::Proto(event))); - } - } - - fn forward_app_events(&mut self, shared: &Shared) { + pub(crate) fn forward_app_events(&mut self, shared: &Shared) { while let Some(event) = self.inner.poll() { use proto::Event::*; match event { @@ -949,61 +894,14 @@ impl State { } } - fn drive_timer(&mut self, cx: &mut Context) -> bool { - // Check whether we need to (re)set the timer. If so, we must poll again to ensure the - // timer is registered with the runtime (and check whether it's already - // expired). - match self.inner.poll_timeout() { - Some(deadline) => { - if let Some(delay) = &mut self.timer { - // There is no need to reset the tokio timer if the deadline - // did not change - if self - .timer_deadline - .map(|current_deadline| current_deadline != deadline) - .unwrap_or(true) - { - delay.as_mut().reset(deadline); - } - } else { - self.timer = Some(self.runtime.new_timer(deadline)); - } - // Store the actual expiration time of the timer - self.timer_deadline = Some(deadline); - } - None => { - self.timer_deadline = None; - return false; - } - } - - if self.timer_deadline.is_none() { - return false; - } - - let delay = self - .timer - .as_mut() - .expect("timer must exist in this state") - .as_mut(); - if delay.poll(cx).is_pending() { - // Since there wasn't a timeout event, there is nothing new - // for the connection to do - return false; - } - - // A timer expired, so the caller needs to check for - // new transmits, which might cause new timers to be set. - self.inner.handle_timeout(Instant::now()); - self.timer_deadline = None; - true - } - - /// Wake up a blocked `Driver` task to process I/O + /// Wake up endpoint to process I/O by marking it as "dirty" for the endpoint pub(crate) fn wake(&mut self) { - if let Some(x) = self.driver.take() { - x.wake(); + if self.is_dirty { + return; } + self.is_dirty = true; + // Take no action if the endpoint is gone + let _ = self.dirty.send(self.handle); } /// Used to wake up all blocked futures when the connection becomes closed for any reason @@ -1058,18 +956,6 @@ impl State { } } -impl Drop for State { - fn drop(&mut self) { - if !self.inner.is_drained() { - // Ensure the endpoint can tidy up - let _ = self.endpoint_events.send(( - self.handle, - EndpointEvent::Proto(proto::EndpointEvent::drained()), - )); - } - } -} - impl fmt::Debug for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("State").field("inner", &self.inner).finish() diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 94718f45a..1265452d0 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -20,12 +20,13 @@ use proto::{ }; use rustc_hash::FxHashMap; use tokio::sync::{futures::Notified, mpsc, Notify}; +use tokio_util::time::DelayQueue; use udp::{RecvMeta, UdpState, BATCH_SIZE}; use crate::{ connection::{Connecting, ConnectionRef}, work_limiter::WorkLimiter, - EndpointConfig, EndpointEvent, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND, SEND_TIME_BOUND, + EndpointConfig, VarInt, RECV_TIME_BOUND, SEND_TIME_BOUND, }; /// A QUIC endpoint. 
@@ -119,7 +120,6 @@ impl Endpoint { socket, proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new)), addr.is_ipv6(), - runtime.clone(), ); let driver = EndpointDriver(rc.clone()); runtime.spawn(Box::pin(async { @@ -192,9 +192,8 @@ impl Endpoint { }; let (ch, conn) = endpoint.inner.connect(config, addr, server_name)?; let udp_state = endpoint.udp_state.clone(); - Ok(endpoint - .connections - .insert(ch, conn, udp_state, self.runtime.clone())) + let dirty = endpoint.dirty_send.clone(); + Ok(endpoint.connections.insert(dirty, ch, conn, udp_state)) } /// Switch to a new UDP socket @@ -297,15 +296,13 @@ impl Future for EndpointDriver { #[allow(unused_mut)] // MSRV fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut endpoint = self.0.state.lock().unwrap(); + let mut endpoint = &mut *self.0.state.lock().unwrap(); if endpoint.driver.is_none() { endpoint.driver = Some(cx.waker().clone()); } - let now = Instant::now(); - let mut keep_going = false; - keep_going |= endpoint.drive_recv(cx, now)?; - keep_going |= endpoint.handle_events(cx, &self.0.shared); + let mut keep_going = endpoint.drive_recv(cx, Instant::now())?; + keep_going |= endpoint.drive_connections(cx, &self.0.shared); keep_going |= endpoint.drive_send(cx)?; if !endpoint.incoming.is_empty() { @@ -315,10 +312,6 @@ impl Future for EndpointDriver { if endpoint.ref_count == 0 && endpoint.connections.is_empty() { Poll::Ready(Ok(())) } else { - drop(endpoint); - // If there is more work to do schedule the endpoint task again. - // `wake_by_ref()` is called outside the lock to minimize - // lock contention on a multithreaded runtime. if keep_going { cx.waker().wake_by_ref(); } @@ -351,14 +344,19 @@ pub(crate) struct State { driver: Option, ipv6: bool, connections: ConnectionSet, - events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>, /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, driver_lost: bool, recv_limiter: WorkLimiter, recv_buf: Box<[u8]>, send_limiter: WorkLimiter, - runtime: Arc, + /// Connections add themselves to this queue when they need to be driven + /// + /// Occurs e.g. 
due to application-layer activity + dirty_recv: mpsc::UnboundedReceiver, + /// Passed in to connections to enable the above + dirty_send: mpsc::UnboundedSender, + timers: DelayQueue, } #[derive(Debug)] @@ -396,10 +394,10 @@ impl State { { Some((handle, DatagramEvent::NewConnection(conn))) => { let conn = self.connections.insert( + self.dirty_send.clone(), handle, conn, self.udp_state.clone(), - self.runtime.clone(), ); self.incoming.push_back(conn); } @@ -478,44 +476,86 @@ impl State { result } - fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool { - use EndpointEvent::*; - - for _ in 0..IO_LOOP_BOUND { - match self.events.poll_recv(cx) { - Poll::Ready(Some((ch, event))) => match event { - Proto(e) => { - if e.is_drained() { - self.connections.refs.remove(&ch); - if self.connections.is_empty() { - shared.idle.notify_waiters(); - } - } - if let Some(event) = self.inner.handle_event(ch, e) { - let conn = self.connections.refs.get(&ch).unwrap(); - let mut conn = conn.state.lock("handle_event"); - conn.inner.handle_event(event); - conn.wake(); + /// Process connections on which there's been timeouts, packets received, or application + /// activity ("dirty" connections) + fn drive_connections(&mut self, cx: &mut Context, shared: &Shared) -> bool { + let mut keep_going = false; + + while let Poll::Ready(Some(result)) = self.timers.poll_expired(cx) { + let conn_handle = result.unwrap().into_inner(); + let conn = match self.connections.refs.get(&conn_handle) { + Some(c) => c, + None => continue, + }; + let mut state = &mut *conn.state.lock("poll timeouts"); + let _guard = state.span.clone().entered(); + state.inner.handle_timeout(Instant::now()); + state.timer_handle = None; + state.timer_deadline = None; + state.wake(); + } + + let mut dirty_buffer = Vec::new(); + + // Buffer the list of initially dirty connections, guaranteeing that the connection + // processing loop below takes a predictable amount of time. 
+ while let Poll::Ready(Some(conn_handle)) = self.dirty_recv.poll_recv(cx) { + dirty_buffer.push(conn_handle); + } + + let mut drained = Vec::new(); + for conn_handle in dirty_buffer { + let conn = match self.connections.refs.get(&conn_handle) { + Some(c) => c, + None => continue, + }; + let mut state = conn.state.lock("poll dirty"); + state.is_dirty = false; + let _guard = state.span.clone().entered(); + let mut keep_conn_going = state.drive_transmit(&mut self.outgoing); + if let Some(deadline) = state.inner.poll_timeout() { + let deadline = tokio::time::Instant::from(deadline); + if Some(deadline) != state.timer_deadline { + match state.timer_handle { + Some(ref key) => self.timers.reset_at(key, deadline), + None => { + state.timer_handle = Some(self.timers.insert_at(conn_handle, deadline)); } } - Transmit(t) => self.outgoing.push_back(t), - }, - Poll::Ready(None) => unreachable!("EndpointInner owns one sender"), - Poll::Pending => { - return false; + // self.timers may need to be polled + keep_going = true; } } + while let Some(event) = state.inner.poll_endpoint_events() { + if event.is_drained() { + drained.push(conn_handle); + } + if let Some(event) = self.inner.handle_event(conn_handle, event) { + state.inner.handle_event(event); + keep_conn_going = true; + } + } + state.forward_app_events(&conn.shared); + if keep_conn_going { + state.wake(); + keep_going = true; + } } - true + for conn_handle in drained { + self.connections.refs.remove(&conn_handle); + } + if self.connections.is_empty() { + shared.idle.notify_waiters(); + } + + keep_going } } #[derive(Debug)] struct ConnectionSet { refs: FxHashMap, - /// Stored to give out clones to new ConnectionInners - sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, /// Set if the endpoint has been manually closed close: Option<(VarInt, Bytes)>, } @@ -523,12 +563,12 @@ struct ConnectionSet { impl ConnectionSet { fn insert( &mut self, + dirty: mpsc::UnboundedSender, handle: ConnectionHandle, conn: proto::Connection, udp_state: Arc, - runtime: Arc, ) -> Connecting { - let (future, conn) = Connecting::new(handle, conn, self.sender.clone(), udp_state, runtime); + let (future, conn) = Connecting::new(dirty, handle, conn, udp_state); if let Some((error_code, ref reason)) = self.close { let mut state = conn.state.lock("close"); state.close(error_code, reason.clone(), &conn.shared); @@ -589,12 +629,7 @@ impl<'a> Future for Accept<'a> { pub(crate) struct EndpointRef(Arc); impl EndpointRef { - pub(crate) fn new( - socket: Box, - inner: proto::Endpoint, - ipv6: bool, - runtime: Arc, - ) -> Self { + pub(crate) fn new(socket: Box, inner: proto::Endpoint, ipv6: bool) -> Self { let udp_state = Arc::new(UdpState::new()); let recv_buf = vec![ 0; @@ -602,7 +637,7 @@ impl EndpointRef { * udp_state.gro_segments() * BATCH_SIZE ]; - let (sender, events) = mpsc::unbounded_channel(); + let (dirty_send, dirty_recv) = mpsc::unbounded_channel(); Self(Arc::new(EndpointInner { shared: Shared { incoming: Notify::new(), @@ -613,13 +648,11 @@ impl EndpointRef { udp_state, inner, ipv6, - events, outgoing: VecDeque::new(), incoming: VecDeque::new(), driver: None, connections: ConnectionSet { refs: FxHashMap::default(), - sender, close: None, }, ref_count: 0, @@ -627,7 +660,9 @@ impl EndpointRef { recv_buf: recv_buf.into(), recv_limiter: WorkLimiter::new(RECV_TIME_BOUND), send_limiter: WorkLimiter::new(SEND_TIME_BOUND), - runtime, + dirty_recv, + dirty_send, + timers: DelayQueue::new(), }), })) } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 
e3d420384..70cff11e8 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -81,18 +81,6 @@ pub use crate::send_stream::{SendStream, StoppedError, WriteError}; #[cfg(test)] mod tests; -#[derive(Debug)] -enum EndpointEvent { - Proto(proto::EndpointEvent), - Transmit(proto::Transmit), -} - -/// Maximum number of datagrams processed in send/recv calls to make before moving on to other processing -/// -/// This helps ensure we don't starve anything when the CPU is slower than the link. -/// Value is selected by picking a low number which didn't degrade throughput in benchmarks. -const IO_LOOP_BOUND: usize = 160; - /// The maximum amount of time that should be spent in `recvmsg()` calls per endpoint iteration /// /// 50us are chosen so that an endpoint iteration with a 50us sendmsg limit blocks From d38ec466cf43e933a70129819080b80f08772650 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sat, 24 Sep 2022 09:56:42 -0700 Subject: [PATCH 3/5] Remove unnecessary sharing of UdpState --- quinn/src/connection.rs | 14 +++++--------- quinn/src/endpoint.rs | 14 ++++++-------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index e36a983b7..63d3660ee 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -21,7 +21,6 @@ use tokio::{ }; use tokio_util::time::delay_queue; use tracing::debug_span; -use udp::UdpState; use crate::{ mutex::Mutex, @@ -45,7 +44,6 @@ impl Connecting { dirty: mpsc::UnboundedSender, handle: ConnectionHandle, conn: proto::Connection, - udp_state: Arc, ) -> (Connecting, ConnectionRef) { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); let (on_connected_send, on_connected_recv) = oneshot::channel(); @@ -55,7 +53,6 @@ impl Connecting { dirty, on_handshake_data_send, on_connected_send, - udp_state, ); ( @@ -698,7 +695,6 @@ impl ConnectionRef { dirty: mpsc::UnboundedSender, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, - udp_state: Arc, ) -> Self { let _ = dirty.send(handle); Self(Arc::new(ConnectionInner { @@ -719,7 +715,6 @@ impl ConnectionRef { stopped: FxHashMap::default(), error: None, ref_count: 0, - udp_state, }), shared: Shared::default(), })) @@ -799,16 +794,17 @@ pub(crate) struct State { pub(crate) error: Option, /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, - udp_state: Arc, } impl State { - pub(crate) fn drive_transmit(&mut self, out: &mut VecDeque) -> bool { + pub(crate) fn drive_transmit( + &mut self, + out: &mut VecDeque, + max_datagrams: usize, + ) -> bool { let now = Instant::now(); let mut transmits = 0; - let max_datagrams = self.udp_state.max_gso_segments(); - while let Some(t) = self.inner.poll_transmit(now, max_datagrams) { transmits += match t.segment_size { None => 1, diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 1265452d0..254464dad 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -191,9 +191,8 @@ impl Endpoint { addr }; let (ch, conn) = endpoint.inner.connect(config, addr, server_name)?; - let udp_state = endpoint.udp_state.clone(); let dirty = endpoint.dirty_send.clone(); - Ok(endpoint.connections.insert(dirty, ch, conn, udp_state)) + Ok(endpoint.connections.insert(dirty, ch, conn)) } /// Switch to a new UDP socket @@ -337,7 +336,7 @@ pub(crate) struct EndpointInner { #[derive(Debug)] pub(crate) struct State { socket: Box, - udp_state: Arc, + udp_state: UdpState, inner: proto::Endpoint, outgoing: VecDeque, incoming: VecDeque, @@ 
-397,7 +396,6 @@ impl State { self.dirty_send.clone(), handle, conn, - self.udp_state.clone(), ); self.incoming.push_back(conn); } @@ -503,6 +501,7 @@ impl State { dirty_buffer.push(conn_handle); } + let max_datagrams = self.udp_state.max_gso_segments(); let mut drained = Vec::new(); for conn_handle in dirty_buffer { let conn = match self.connections.refs.get(&conn_handle) { @@ -512,7 +511,7 @@ impl State { let mut state = conn.state.lock("poll dirty"); state.is_dirty = false; let _guard = state.span.clone().entered(); - let mut keep_conn_going = state.drive_transmit(&mut self.outgoing); + let mut keep_conn_going = state.drive_transmit(&mut self.outgoing, max_datagrams); if let Some(deadline) = state.inner.poll_timeout() { let deadline = tokio::time::Instant::from(deadline); if Some(deadline) != state.timer_deadline { @@ -566,9 +565,8 @@ impl ConnectionSet { dirty: mpsc::UnboundedSender, handle: ConnectionHandle, conn: proto::Connection, - udp_state: Arc, ) -> Connecting { - let (future, conn) = Connecting::new(dirty, handle, conn, udp_state); + let (future, conn) = Connecting::new(dirty, handle, conn); if let Some((error_code, ref reason)) = self.close { let mut state = conn.state.lock("close"); state.close(error_code, reason.clone(), &conn.shared); @@ -630,7 +628,7 @@ pub(crate) struct EndpointRef(Arc); impl EndpointRef { pub(crate) fn new(socket: Box, inner: proto::Endpoint, ipv6: bool) -> Self { - let udp_state = Arc::new(UdpState::new()); + let udp_state = UdpState::new(); let recv_buf = vec![ 0; inner.config().get_max_udp_payload_size().min(64 * 1024) as usize From aa92ba963df1d949b2a62622a305e145751a421e Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Tue, 27 Sep 2022 12:34:05 -0700 Subject: [PATCH 4/5] Move timer handling into an isolated method --- quinn/src/endpoint.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 254464dad..bc2946c7c 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -432,6 +432,22 @@ impl State { Ok(false) } + fn drive_timers(&mut self, cx: &mut Context, now: Instant) { + while let Poll::Ready(Some(result)) = self.timers.poll_expired(cx) { + let conn_handle = result.unwrap().into_inner(); + let conn = match self.connections.refs.get(&conn_handle) { + Some(c) => c, + None => continue, + }; + let mut state = &mut *conn.state.lock("poll timeouts"); + let _guard = state.span.clone().entered(); + state.inner.handle_timeout(now); + state.timer_handle = None; + state.timer_deadline = None; + state.wake(); + } + } + fn drive_send(&mut self, cx: &mut Context) -> Result { self.send_limiter.start_cycle(); @@ -479,19 +495,7 @@ impl State { fn drive_connections(&mut self, cx: &mut Context, shared: &Shared) -> bool { let mut keep_going = false; - while let Poll::Ready(Some(result)) = self.timers.poll_expired(cx) { - let conn_handle = result.unwrap().into_inner(); - let conn = match self.connections.refs.get(&conn_handle) { - Some(c) => c, - None => continue, - }; - let mut state = &mut *conn.state.lock("poll timeouts"); - let _guard = state.span.clone().entered(); - state.inner.handle_timeout(Instant::now()); - state.timer_handle = None; - state.timer_deadline = None; - state.wake(); - } + self.drive_timers(cx, Instant::now()); let mut dirty_buffer = Vec::new(); From 32c80cca473be96577cab99db6404df7e08a31ed Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sun, 28 Aug 2022 14:50:13 -0700 Subject: [PATCH 5/5] Portable delay queue --- 
quinn/Cargo.toml | 3 +- quinn/src/connection.rs | 11 +- quinn/src/delay_queue.rs | 588 +++++++++++++++++++++++++++++++++++++++ quinn/src/endpoint.rs | 58 +++- quinn/src/lib.rs | 1 + 5 files changed, 642 insertions(+), 19 deletions(-) create mode 100644 quinn/src/delay_queue.rs diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index 36fd00067..509224b5c 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -42,10 +42,10 @@ rustc-hash = "1.1" pin-project-lite = "0.2" proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.9", default-features = false } rustls = { version = "0.20.3", default-features = false, features = ["quic"], optional = true } +slab = "0.4" thiserror = "1.0.21" tracing = "0.1.10" tokio = { version = "1.13.0", features = ["sync"] } -tokio-util = { version = "0.6.9", features = ["time"] } udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.3", default-features = false } webpki = { version = "0.22", default-features = false, optional = true } @@ -54,6 +54,7 @@ anyhow = "1.0.22" crc = "3" bencher = "0.1.5" directories-next = "2" +proptest = "=1.0.0" # Pinned for MSRV rand = "0.8" rcgen = "0.10.0" rustls-pemfile = "1.0.0" diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 63d3660ee..20cdc6671 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -15,14 +15,11 @@ use pin_project_lite::pin_project; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; use rustc_hash::FxHashMap; use thiserror::Error; -use tokio::{ - sync::{futures::Notified, mpsc, oneshot, Notify}, - time::Instant as TokioInstant, -}; -use tokio_util::time::delay_queue; +use tokio::sync::{futures::Notified, mpsc, oneshot, Notify}; use tracing::debug_span; use crate::{ + delay_queue::Timer, mutex::Mutex, recv_stream::RecvStream, send_stream::{SendStream, WriteError}, @@ -784,8 +781,8 @@ pub(crate) struct State { on_handshake_data: Option>, on_connected: Option>, connected: bool, - pub(crate) timer_handle: Option, - pub(crate) timer_deadline: Option, + pub(crate) timer_handle: Option, + pub(crate) timer_deadline: Option, pub(crate) blocked_writers: FxHashMap, pub(crate) blocked_readers: FxHashMap, pub(crate) finishing: FxHashMap>>, diff --git a/quinn/src/delay_queue.rs b/quinn/src/delay_queue.rs new file mode 100644 index 000000000..14266edbc --- /dev/null +++ b/quinn/src/delay_queue.rs @@ -0,0 +1,588 @@ +use std::{fmt, ops::RangeInclusive}; + +use slab::Slab; + +/// Stores values to be yielded at specific times in the future +/// +/// Time is expressed as a bare u64 representing an absolute point in time. The caller may use any +/// consistent unit, e.g. milliseconds, and any consistent definition of time zero. Larger units +/// limit resolution but make `poll`ing over the same real-time interval proportionately faster, +/// whereas smaller units improve resolution, limit total range, and reduce `poll` performance. +#[derive(Debug)] +pub struct DelayQueue { + /// Definitions of each active timer + /// + /// Timers are defined here, and referenced indirectly by index from `levels` and in the public + /// API. This allows for safe construction of intrusive linked lists between timers, and helps + /// reduce the amount of data that needs to be routinely shuffled around in `levels` as time + /// passes. + timers: Slab>, + + /// A hierarchical timer wheel + /// + /// This data structure breaks down points in time into digits. 
The base of those digits can be + /// chosen arbitrarily; this implementation uses base `2^LOG_2_SLOTS`. A power of two makes it + /// easy to manipulate individual digits using bit shifts and masking because each digit + /// corresponds directly to `LOG_2_SLOTS` bits in the binary representation. For familiarity, we + /// will illustrate a timer wheel built instead on base 10, but the behavior is identical. + /// + /// Consider this timer wheel where timers are set at times 32, 42, and 46, and `next_tick` is + /// between 30 and 32 inclusive. Note that the number of slots in each level is equal to the + /// base of the digits used, in this case 10. + /// + /// ```text + /// +--+--+--+--+-- + /// Level 0 |30|31|32|33| ... + /// +--+--+--+--+-- + /// \ | / + /// \ V / + /// \ +--+ / + /// \ |32| / + /// \+--+ / + /// \ / + /// +--+--+--+--+--+--+--+--+--+--+ + /// Level 1 |00|10|20|30|40|50|60|70|80|90| + /// +--+--+--+--+--+--+--+--+--+--+ + /// | + /// V + /// +--+ + /// |46| + /// +--+ + /// ^| + /// |V + /// +--+ + /// |42| + /// +--+ + /// ``` + /// + /// Timers are organized into buckets (or slots) at a resolution that decreases exponentially + /// with distance from `next_tick`, the present. Higher-numbered levels cover larger intervals, + /// until the highest-numbered level covers the complete representable of timers, from 0 to + /// `u64::MAX`. Every lower level covers the slot in the next highest level which `next_tick` + /// lies within. Level 0 represents the maximum resolution, where each slot covers exactly one + /// unit of time. + /// + /// The slot that a timer should be stored in is easily computed based on `next_tick` and the + /// desired expiry time. For a base 10 structure, find the most significant digit in the base 10 + /// representations of `next_tick` and the desired expiry time that differs between the two. The + /// position of that digit is the level, and the value of that digit is the position in the + /// level. For example, if `next_tick` is 7342, and a timer is scheduled for time 7361, the + /// timer would be stored at level 1, slot 6. Note that no subtraction is performed: the start + /// of each level is always the greatest integer multiple of the level's span which is less than + /// or equal to `next_tick`. + /// + /// Calls to `poll` move `next_tick` towards the passed-in time. When `next_tick` reaches a + /// timer in level 0, it stops there and the timer is removed and returned from `poll`. Reaching + /// the end of level 0 redefines level 0 to represent the next slot in level 1, at which point + /// all timers stored in that slot are unpacked into appropriate slots of level 0, and traversal + /// of level 0 begins again from the start. When level 1 is exhausted, the next slot in level 2 + /// is unpacked into levels 1 and 0, and so on for higher levels. Slots preceding `next_tick` + /// are therefore empty at any level, and for levels above 0, the slot containing `next_tick` is + /// also empty, having necessarily been unpacked into lower levels. + /// + /// Assuming the number of timers scheduled within a period of time is on average proportional + /// to the size of that period, advancing the queue by a constant amount of time has amortized + /// constant time complexity, because the frequency with which slots at a particular level are + /// unpacked is inversely proportional to the expected number of timers stored in that + /// slot. 
+ /// + /// Inserting, removing, and updating timers are constant-time operations thanks to the above + /// and the use of unordered doubly linked lists to represent the contents of a slot. We can + /// also compute a lower bound for the next timeout in constant time by scanning for the + /// earliest nonempty slot. + levels: [Level; LEVELS], + + /// Earliest point at which a timer may be pending + /// + /// Each `LOG_2_SLOTS` bits of this are a cursor into the associated level, in order of + /// ascending significance. + next_tick: u64, +} + +impl DelayQueue { + /// Create an empty queue starting at time `0` + pub fn new() -> Self { + Self { + timers: Slab::new(), + levels: [Level::new(); LEVELS], + next_tick: 0, + } + } + + /// Returns a timer that has expired by `now`, if any + /// + /// `now` must be at least the largest previously passed value + pub fn poll(&mut self, now: u64) -> Option { + debug_assert!(now >= self.next_tick, "time advances monotonically"); + loop { + // Advance towards the next timeout + self.advance_towards(now); + // Check for timeouts in the immediate future + if let Some(value) = self.scan_bottom(now) { + return Some(value); + } + // If we can't advance any further, bail out + if self.next_tick >= now { + return None; + } + } + } + + /// Find a timer expired by `now` in level 0 + fn scan_bottom(&mut self, now: u64) -> Option { + if let Some((slot, timer)) = self.levels[0].slots[range_in_level(0, self.next_tick..=now)] + .iter_mut() + .find_map(|x| x.take().map(|timer| (x, timer))) + { + let state = self.timers.remove(timer.0); + debug_assert_eq!(state.prev, None, "head of list has no predecessor"); + debug_assert!(state.expiry <= now); + if let Some(next) = state.next { + debug_assert_eq!( + self.timers[next.0].prev, + Some(timer), + "successor links to head" + ); + self.timers[next.0].prev = None; + } + *slot = state.next; + self.next_tick = state.expiry; + self.maybe_shrink(); + return Some(state.value); + } + None + } + + /// Advance to the start of the first nonempty slot or `now`, whichever is sooner + fn advance_towards(&mut self, now: u64) { + for level in 0..LEVELS { + for slot in range_in_level(level, self.next_tick..=now) { + debug_assert!( + now >= slot_start(self.next_tick, level, slot), + "slot overlaps with the past" + ); + if self.levels[level].slots[slot].is_some() { + self.advance_to(level, slot); + return; + } + } + } + self.next_tick = now; + } + + /// Advance to a specific slot, which must be the first nonempty slot + fn advance_to(&mut self, level: usize, slot: usize) { + debug_assert!( + self.levels[..level] + .iter() + .all(|level| level.slots.iter().all(|x| x.is_none())), + "lower levels are empty" + ); + debug_assert!( + self.levels[level].slots[..slot].iter().all(Option::is_none), + "lower slots in this level are empty" + ); + + // Advance into the slot + self.next_tick = slot_start(self.next_tick, level, slot); + + if level == 0 { + // No lower levels exist to unpack timers into + return; + } + + // Unpack all timers in this slot into lower levels + while let Some(timer) = self.levels[level].slots[slot].take() { + let next = self.timers[timer.0].next; + self.levels[level].slots[slot] = next; + if let Some(next) = next { + self.timers[next.0].prev = None; + } + self.list_unlink(timer); + self.schedule(timer); + } + } + + /// Link `timer` from the slot associated with its expiry + fn schedule(&mut self, timer: Timer) { + debug_assert_eq!( + self.timers[timer.0].next, None, + "timer isn't already scheduled" + ); + debug_assert_eq!( + 
self.timers[timer.0].prev, None, + "timer isn't already scheduled" + ); + let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry); + // Insert `timer` at the head of the list in the target slot + let head = self.levels[level].slots[slot]; + self.timers[timer.0].next = head; + if let Some(head) = head { + self.timers[head.0].prev = Some(timer); + } + self.levels[level].slots[slot] = Some(timer); + } + + /// Lower bound on when the next timer will expire, if any + pub fn next_timeout(&self) -> Option { + for level in 0..LEVELS { + let start = ((self.next_tick >> (level * LOG_2_SLOTS)) & (SLOTS - 1) as u64) as usize; + for slot in start..SLOTS { + if self.levels[level].slots[slot].is_some() { + return Some(slot_start(self.next_tick, level, slot)); + } + } + } + None + } + + /// Register a timer that will yield `value` at `timeout` + pub fn insert(&mut self, timeout: u64, value: T) -> Timer { + let timer = Timer(self.timers.insert(TimerState { + expiry: timeout.max(self.next_tick), + prev: None, + next: None, + value, + })); + self.schedule(timer); + timer + } + + /// Adjust `timer` to expire at `timeout` + pub fn reset(&mut self, timer: Timer, timeout: u64) { + self.unlink(timer); + self.timers[timer.0].expiry = timeout.max(self.next_tick); + self.schedule(timer); + } + + /// Cancel `timer` + #[cfg(test)] + pub fn remove(&mut self, timer: Timer) -> T { + self.unlink(timer); + let state = self.timers.remove(timer.0); + self.maybe_shrink(); + state.value + } + + /// Release timer state memory if it's mostly unused + fn maybe_shrink(&mut self) { + if self.timers.capacity() / 16 > self.timers.len() { + self.timers.shrink_to_fit(); + } + } + + /// Remove all references to `timer` + fn unlink(&mut self, timer: Timer) { + let (level, slot) = timer_index(self.next_tick, self.timers[timer.0].expiry); + // If necessary, remove a reference to `timer` from its slot by replacing it with its + // successor + let slot_head = self.levels[level].slots[slot].unwrap(); + if slot_head == timer { + self.levels[level].slots[slot] = self.timers[slot_head.0].next; + debug_assert_eq!( + self.timers[timer.0].prev, None, + "head of list has no predecessor" + ); + } + // Remove references to `timer` from other timers + self.list_unlink(timer); + } + + /// Remove `timer` from its list + fn list_unlink(&mut self, timer: Timer) { + let prev = self.timers[timer.0].prev.take(); + let next = self.timers[timer.0].next.take(); + if let Some(prev) = prev { + // Remove reference from predecessor + self.timers[prev.0].next = next; + } + if let Some(next) = next { + // Remove reference from successor + self.timers[next.0].prev = prev; + } + } +} + +fn range_in_level(level: usize, raw: RangeInclusive) -> RangeInclusive { + let shift = level * LOG_2_SLOTS; + const MASK: u64 = SLOTS as u64 - 1; + let start = ((*raw.start() >> shift) & MASK) as usize; + let level_end = (*raw.start() >> shift) | MASK; + let end = ((*raw.end() >> shift).min(level_end) & MASK) as usize; + start..=end +} + +/// Compute the first tick that lies within a slot +fn slot_start(base: u64, level: usize, slot: usize) -> u64 { + let shift = (level * LOG_2_SLOTS) as u64; + // Shifting twice avoids an overflow when level = 10. 
+ (base & ((!0 << shift) << LOG_2_SLOTS as u64)) | ((slot as u64) << shift) +} + +/// Compute the level and slot for a certain expiry +fn timer_index(base: u64, expiry: u64) -> (usize, usize) { + // The level is the position of the first bit set in `expiry` but not in `base`, divided by the + // number of bits spanned by each level. + let differing_bits = base ^ expiry; + let level = (63 - (differing_bits | 1).leading_zeros()) as usize / LOG_2_SLOTS; + debug_assert!(level < LEVELS, "every possible expiry is in range"); + + // The slot in that level is the difference between the expiry time and the time at which the + // level's span begins, after both times are shifted down to the level's granularity. Each + // level's spans starts at `base`, rounded down to a multiple of the size of its span. + let slot_base = (base >> (level * LOG_2_SLOTS)) & (!0 << LOG_2_SLOTS); + let slot = (expiry >> (level * LOG_2_SLOTS)) - slot_base; + debug_assert!(slot < SLOTS as u64); + + (level, slot as usize) +} + +impl Default for DelayQueue { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct TimerState { + /// Lowest argument to `poll` for which this timer may be returned + expiry: u64, + /// Value returned to the caller on expiry + value: T, + /// Predecessor within a slot's list + prev: Option, + /// Successor within a slot's list + next: Option, +} + +/// A set of contiguous timer lists, ordered by expiry +/// +/// Level `n` spans `2^(LOG_2_SLOTS * (n+1))` ticks, and each of its slots corresponds to a span of +/// `2^(LOG_2_SLOTS * n)`. +#[derive(Copy, Clone)] +struct Level { + slots: [Option; SLOTS], +} + +impl Level { + fn new() -> Self { + Self { + slots: [None; SLOTS], + } + } +} + +impl fmt::Debug for Level { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut m = f.debug_map(); + let numbered_nonempty_slots = self + .slots + .iter() + .enumerate() + .filter_map(|(i, x)| x.map(|t| (i, t))); + for (i, Timer(t)) in numbered_nonempty_slots { + m.entry(&i, &t); + } + m.finish() + } +} + +const LOG_2_SLOTS: usize = 6; +const LEVELS: usize = 1 + 64 / LOG_2_SLOTS; +const SLOTS: usize = 1 << LOG_2_SLOTS; + +// Index in `DelayQueue::timers`. Future work: add a niche here. 
+#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct Timer(usize); + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + use proptest::prelude::*; + + #[test] + fn max_timeout() { + let mut queue = DelayQueue::new(); + queue.insert(u64::MAX, ()); + assert!(queue.poll(u64::MAX - 1).is_none()); + assert!(queue.poll(u64::MAX).is_some()); + } + + #[test] + fn level_ranges() { + assert_eq!(range_in_level(0, 0..=1), 0..=1); + assert_eq!(range_in_level(0, 0..=SLOTS as u64), 0..=SLOTS - 1); + assert_eq!(range_in_level(1, 0..=SLOTS as u64), 0..=1); + assert_eq!(range_in_level(1, 0..=(SLOTS as u64).pow(2)), 0..=SLOTS - 1); + assert_eq!(range_in_level(2, 0..=(SLOTS as u64).pow(2)), 0..=1); + } + + #[test] + fn slot_starts() { + for i in 0..SLOTS { + assert_eq!(slot_start(0, 0, i), i as u64); + assert_eq!(slot_start(SLOTS as u64, 0, i), SLOTS as u64 + i as u64); + assert_eq!(slot_start(SLOTS as u64 + 1, 0, i), SLOTS as u64 + i as u64); + for j in 1..LEVELS { + assert_eq!( + slot_start(0, j, i), + (SLOTS as u64).pow(j as u32).wrapping_mul(i as u64) + ); + } + } + } + + #[test] + fn indexes() { + assert_eq!(timer_index(0, 0), (0, 0)); + assert_eq!(timer_index(0, SLOTS as u64 - 1), (0, SLOTS - 1)); + assert_eq!( + timer_index(SLOTS as u64 - 1, SLOTS as u64 - 1), + (0, SLOTS - 1) + ); + assert_eq!(timer_index(0, SLOTS as u64), (1, 1)); + for i in 0..LEVELS { + assert_eq!(timer_index(0, (SLOTS as u64).pow(i as u32)), (i, 1)); + if i < LEVELS - 1 { + assert_eq!( + timer_index(0, (SLOTS as u64).pow(i as u32 + 1) - 1), + (i, SLOTS - 1) + ); + assert_eq!( + timer_index(SLOTS as u64 - 1, (SLOTS as u64).pow(i as u32 + 1) - 1), + (i, SLOTS - 1) + ); + } + } + } + + #[test] + fn next_timeout() { + let mut queue = DelayQueue::new(); + assert_eq!(queue.next_timeout(), None); + let k = queue.insert(0, ()); + assert_eq!(queue.next_timeout(), Some(0)); + queue.remove(k); + assert_eq!(queue.next_timeout(), None); + queue.insert(1234, ()); + assert!(queue.next_timeout().unwrap() > 12); + queue.insert(12, ()); + assert_eq!(queue.next_timeout(), Some(12)); + } + + #[test] + fn poll_boundary() { + let mut queue = DelayQueue::new(); + queue.insert(SLOTS as u64 - 1, 'a'); + queue.insert(SLOTS as u64, 'b'); + assert_eq!(queue.poll(SLOTS as u64 - 2), None); + assert_eq!(queue.poll(SLOTS as u64 - 1), Some('a')); + assert_eq!(queue.poll(SLOTS as u64 - 1), None); + assert_eq!(queue.poll(SLOTS as u64), Some('b')); + } + + #[test] + /// Validate that `reset` properly updates intrusive list links + fn reset_list_middle() { + let mut queue = DelayQueue::new(); + let slot = SLOTS as u64 / 2; + let a = queue.insert(slot, ()); + let b = queue.insert(slot, ()); + let c = queue.insert(slot, ()); + + queue.reset(b, slot + 1); + + assert_eq!(queue.levels[0].slots[slot as usize + 1], Some(b)); + assert_eq!(queue.timers[b.0].prev, None); + assert_eq!(queue.timers[b.0].next, None); + + assert_eq!(queue.levels[0].slots[slot as usize], Some(c)); + assert_eq!(queue.timers[c.0].prev, None); + assert_eq!(queue.timers[c.0].next, Some(a)); + assert_eq!(queue.timers[a.0].prev, Some(c)); + assert_eq!(queue.timers[a.0].next, None); + } + + proptest! 
{ + #[test] + fn poll(ts in times()) { + let mut queue = DelayQueue::new(); + let mut time_values = HashMap::>::new(); + for (i, t) in ts.into_iter().enumerate() { + queue.insert(t, i); + time_values.entry(t).or_default().push(i); + } + let mut time_values = time_values.into_iter().collect::)>>(); + time_values.sort_unstable_by_key(|&(t, _)| t); + for &(t, ref is) in &time_values { + assert!(queue.next_timeout().unwrap() <= t); + if t > 0 { + assert_eq!(queue.poll(t-1), None); + } + let mut values = Vec::new(); + while let Some(i) = queue.poll(t) { + values.push(i); + } + assert_eq!(values.len(), is.len()); + for i in is { + assert!(values.contains(i)); + } + } + } + + #[test] + fn reset(ts_a in times(), ts_b in times()) { + let mut queue = DelayQueue::new(); + let timers = ts_a.map(|t| queue.insert(t, ())); + for (timer, t) in timers.into_iter().zip(ts_b) { + queue.reset(timer, t); + } + let mut n = 0; + while let Some(()) = queue.poll(u64::MAX) { + n += 1; + } + assert_eq!(n, timers.len()); + } + + #[test] + fn index_start_consistency(a in time(), b in time()) { + let base = a.min(b); + let t = a.max(b); + let (level, slot) = timer_index(base, t); + let start = slot_start(base, level, slot); + assert!(start <= t); + if let Some(end) = start.checked_add((SLOTS as u64).pow(level as u32)) { + assert!(end > t); + } else { + // Slot contains u64::MAX + assert!(start >= slot_start(0, LEVELS - 1, 15)); + if level == LEVELS - 1 { + assert_eq!(slot, 15); + } else { + assert_eq!(slot, SLOTS - 1); + } + } + } + } + + /// Generates a time whose level/slot is more or less uniformly distributed + fn time() -> impl Strategy { + ((0..LEVELS as u32), (0..SLOTS as u64)).prop_perturb(|(level, mut slot), mut rng| { + if level == LEVELS as u32 - 1 { + slot %= 16; + } + let slot_size = (SLOTS as u64).pow(level); + let slot_start = slot * slot_size; + let slot_end = (slot + 1).saturating_mul(slot_size); + rng.gen_range(slot_start..slot_end) + }) + } + + #[rustfmt::skip] + fn times() -> impl Strategy { + [time(), time(), time(), time(), time(), time(), time(), time(), + time(), time(), time(), time(), time(), time(), time(), time()] + } +} diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index bc2946c7c..746987030 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -12,7 +12,7 @@ use std::{ time::Instant, }; -use crate::runtime::{default_runtime, AsyncUdpSocket, Runtime}; +use crate::runtime::{default_runtime, AsyncTimer, AsyncUdpSocket, Runtime}; use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use proto::{ @@ -20,11 +20,11 @@ use proto::{ }; use rustc_hash::FxHashMap; use tokio::sync::{futures::Notified, mpsc, Notify}; -use tokio_util::time::DelayQueue; use udp::{RecvMeta, UdpState, BATCH_SIZE}; use crate::{ connection::{Connecting, ConnectionRef}, + delay_queue::DelayQueue, work_limiter::WorkLimiter, EndpointConfig, VarInt, RECV_TIME_BOUND, SEND_TIME_BOUND, }; @@ -120,6 +120,7 @@ impl Endpoint { socket, proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new)), addr.is_ipv6(), + runtime.clone(), ); let driver = EndpointDriver(rc.clone()); runtime.spawn(Box::pin(async { @@ -335,6 +336,7 @@ pub(crate) struct EndpointInner { #[derive(Debug)] pub(crate) struct State { + runtime: Arc, socket: Box, udp_state: UdpState, inner: proto::Endpoint, @@ -356,6 +358,8 @@ pub(crate) struct State { /// Passed in to connections to enable the above dirty_send: mpsc::UnboundedSender, timers: DelayQueue, + timer_epoch: Instant, + base_timer: Option>>, } #[derive(Debug)] @@ -432,9 
+436,17 @@ impl State { Ok(false) } - fn drive_timers(&mut self, cx: &mut Context, now: Instant) { - while let Poll::Ready(Some(result)) = self.timers.poll_expired(cx) { - let conn_handle = result.unwrap().into_inner(); + fn drive_timers(&mut self, cx: &mut Context, now: Instant) -> bool { + let mut keep_going = false; + // `DelayQueue::poll` currently yields timers expiring in the same millisecond in LIFO + // order. This doesn't matter so long as we're processing all expiries, but if the below + // loop is ever updated to bail out early to improve fairness under heavy load, then we + // should carefully consider whether serving newer events (more likely to still be relevant) + // or older ones (more likely to allow us to free resources) should take priority. + while let Some(conn_handle) = self + .timers + .poll((now - self.timer_epoch).as_millis() as u64) + { let conn = match self.connections.refs.get(&conn_handle) { Some(c) => c, None => continue, @@ -446,6 +458,20 @@ impl State { state.timer_deadline = None; state.wake(); } + if let Some(deadline) = self.timers.next_timeout() { + let deadline = self.timer_epoch + std::time::Duration::from_millis(deadline); + let timer = match self.base_timer { + Some(ref mut x) => { + x.as_mut().reset(deadline); + x + } + None => self.base_timer.insert(self.runtime.new_timer(deadline)), + }; + if let Poll::Ready(()) = timer.as_mut().poll(cx) { + keep_going = true; + } + } + keep_going } fn drive_send(&mut self, cx: &mut Context) -> Result { @@ -495,7 +521,7 @@ impl State { fn drive_connections(&mut self, cx: &mut Context, shared: &Shared) -> bool { let mut keep_going = false; - self.drive_timers(cx, Instant::now()); + keep_going |= self.drive_timers(cx, Instant::now()); let mut dirty_buffer = Vec::new(); @@ -517,15 +543,17 @@ impl State { let _guard = state.span.clone().entered(); let mut keep_conn_going = state.drive_transmit(&mut self.outgoing, max_datagrams); if let Some(deadline) = state.inner.poll_timeout() { - let deadline = tokio::time::Instant::from(deadline); if Some(deadline) != state.timer_deadline { + let deadline = (deadline - self.timer_epoch).as_millis() as u64; match state.timer_handle { - Some(ref key) => self.timers.reset_at(key, deadline), + Some(key) => { + self.timers.reset(key, deadline); + } None => { - state.timer_handle = Some(self.timers.insert_at(conn_handle, deadline)); + state.timer_handle = Some(self.timers.insert(deadline, conn_handle)); } } - // self.timers may need to be polled + // base timer may need to be updated keep_going = true; } } @@ -631,7 +659,12 @@ impl<'a> Future for Accept<'a> { pub(crate) struct EndpointRef(Arc); impl EndpointRef { - pub(crate) fn new(socket: Box, inner: proto::Endpoint, ipv6: bool) -> Self { + pub(crate) fn new( + socket: Box, + inner: proto::Endpoint, + ipv6: bool, + runtime: Arc, + ) -> Self { let udp_state = UdpState::new(); let recv_buf = vec![ 0; @@ -646,6 +679,7 @@ impl EndpointRef { idle: Notify::new(), }, state: Mutex::new(State { + runtime, socket, udp_state, inner, @@ -665,6 +699,8 @@ impl EndpointRef { dirty_recv, dirty_send, timers: DelayQueue::new(), + timer_epoch: Instant::now(), + base_timer: None, }), })) } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 70cff11e8..218d16efa 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -51,6 +51,7 @@ macro_rules! ready { } mod connection; +mod delay_queue; mod endpoint; mod mutex; mod recv_stream;
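
For reviewers, a minimal usage sketch of the `DelayQueue` added in PATCH 5, mirroring how `State::drive_timers` and `drive_connections` use it. This is an assumption-laden illustration, not part of the patch: it presumes it runs inside the quinn crate (the module is private) and that times are milliseconds measured from an arbitrary `timer_epoch`, as the endpoint does.

    // Sketch only: exercises the DelayQueue API exactly as exposed by this patch
    // (new/insert/reset/next_timeout/poll), with u64 times in milliseconds.
    use std::time::{Duration, Instant};

    use crate::delay_queue::DelayQueue;

    fn delay_queue_usage_sketch() {
        // The endpoint keeps one epoch and converts Instants to ms offsets.
        let timer_epoch = Instant::now();
        let to_ms = |t: Instant| (t - timer_epoch).as_millis() as u64;

        // Values are returned on expiry; the endpoint stores ConnectionHandles,
        // a plain index stands in for one here.
        let mut timers = DelayQueue::<usize>::new();

        // Schedule two timers; keep the returned `Timer` key for connection 1 so
        // its deadline can later be moved without reinserting.
        let key = timers.insert(to_ms(timer_epoch + Duration::from_millis(30)), 1);
        timers.insert(to_ms(timer_epoch + Duration::from_millis(10)), 2);

        // `next_timeout` is a lower bound, used to arm the endpoint's single
        // runtime timer (`base_timer`).
        assert!(timers.next_timeout().unwrap() <= 10);

        // Connection 1's transport deadline changed (e.g. a packet arrived):
        // reset the existing entry in place, as drive_connections does.
        timers.reset(key, to_ms(timer_epoch + Duration::from_millis(50)));

        // On wakeup, drain everything that has expired by "now".
        let now = to_ms(timer_epoch + Duration::from_millis(20));
        while let Some(conn_index) = timers.poll(now) {
            // Only the 10 ms timer has fired; the reset 50 ms timer is still pending.
            assert_eq!(conn_index, 2);
        }
    }

Keeping queue time as a bare `u64` is what makes the queue runtime-agnostic: per-connection timers never touch the runtime, and only the endpoint's single `base_timer` goes through `Runtime::new_timer`.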