From 513212c48b25c4a1db5dc3f597fe17225f60a67f Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Wed, 4 Dec 2024 11:34:10 -0700 Subject: [PATCH] refactor(s2n-quic-dc): reduce the number of `peer_addr` calls (#2401) --- dc/s2n-quic-dc/src/stream/application.rs | 16 ++- .../src/stream/environment/tokio.rs | 4 +- dc/s2n-quic-dc/src/stream/recv/application.rs | 16 ++- dc/s2n-quic-dc/src/stream/send/application.rs | 16 ++- dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs | 111 ++++++++---------- 5 files changed, 84 insertions(+), 79 deletions(-) diff --git a/dc/s2n-quic-dc/src/stream/application.rs b/dc/s2n-quic-dc/src/stream/application.rs index f9d9bc4e9..c36de11e4 100644 --- a/dc/s2n-quic-dc/src/stream/application.rs +++ b/dc/s2n-quic-dc/src/stream/application.rs @@ -114,10 +114,18 @@ where Sub: event::Subscriber, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Stream") - .field("peer_addr", &self.peer_addr().unwrap()) - .field("local_addr", &self.local_addr().unwrap()) - .finish() + let mut s = f.debug_struct("Stream"); + + for (name, addr) in [ + ("peer_addr", self.peer_addr()), + ("local_addr", self.local_addr()), + ] { + if let Ok(addr) = addr { + s.field(name, &addr); + } + } + + s.finish() } } diff --git a/dc/s2n-quic-dc/src/stream/environment/tokio.rs b/dc/s2n-quic-dc/src/stream/environment/tokio.rs index e0d532d70..3c5091b82 100644 --- a/dc/s2n-quic-dc/src/stream/environment/tokio.rs +++ b/dc/s2n-quic-dc/src/stream/environment/tokio.rs @@ -289,7 +289,7 @@ where } /// A socket that should be reregistered with the application runtime -pub struct TcpReregistered(pub TcpStream); +pub struct TcpReregistered(pub TcpStream, pub SocketAddress); impl super::Peer> for TcpReregistered where @@ -308,7 +308,7 @@ where #[inline] fn setup(self, _env: &Environment) -> super::Result> { - let remote_addr = self.0.peer_addr()?.into(); + let remote_addr = self.1; let source_control_port = self.0.local_addr()?.port(); let application = Box::new(self.0.into_std()?); Ok(super::SocketSet { diff --git a/dc/s2n-quic-dc/src/stream/recv/application.rs b/dc/s2n-quic-dc/src/stream/recv/application.rs index 5988e3e4d..89762fc72 100644 --- a/dc/s2n-quic-dc/src/stream/recv/application.rs +++ b/dc/s2n-quic-dc/src/stream/recv/application.rs @@ -60,10 +60,18 @@ where Sub: event::Subscriber, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Reader") - .field("peer_addr", &self.peer_addr().unwrap()) - .field("local_addr", &self.local_addr().unwrap()) - .finish() + let mut s = f.debug_struct("Reader"); + + for (name, addr) in [ + ("peer_addr", self.peer_addr()), + ("local_addr", self.local_addr()), + ] { + if let Ok(addr) = addr { + s.field(name, &addr); + } + } + + s.finish() } } diff --git a/dc/s2n-quic-dc/src/stream/send/application.rs b/dc/s2n-quic-dc/src/stream/send/application.rs index fe309225f..88d1d9e30 100644 --- a/dc/s2n-quic-dc/src/stream/send/application.rs +++ b/dc/s2n-quic-dc/src/stream/send/application.rs @@ -47,10 +47,18 @@ where Sub: event::Subscriber, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Writer") - .field("peer_addr", &self.peer_addr().unwrap()) - .field("local_addr", &self.local_addr().unwrap()) - .finish() + let mut s = f.debug_struct("Writer"); + + for (name, addr) in [ + ("peer_addr", self.peer_addr()), + ("local_addr", self.local_addr()), + ] { + if let Ok(addr) = addr { + s.field(name, &addr); + } + } + + s.finish() } } diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs index 189b0afb0..a52957ec0 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp.rs @@ -102,8 +102,8 @@ where fresh.fill(cx, &self.socket, &publisher); - for socket in fresh.drain() { - workers.push(socket, now, &self.subscriber, &publisher); + for (socket, remote_address) in fresh.drain() { + workers.push(socket, remote_address, now, &self.subscriber, &publisher); } let res = workers.poll(cx, &mut context, now, &publisher); @@ -154,7 +154,7 @@ fn publisher<'a, Sub: Subscriber, C: Clock>( /// /// This should produce overall better latencies in the case of overloaded queues. struct FreshQueue { - queue: VecDeque, + queue: VecDeque<(TcpStream, SocketAddress)>, } impl FreshQueue { @@ -186,32 +186,31 @@ impl FreshQueue { while let Poll::Ready(res) = listener.poll_accept(cx) { match res { - Ok((socket, remote_addr)) => { + Ok((socket, remote_address)) => { if self.queue.len() == self.queue.capacity() { - if let Some(remote_addr) = self + if let Some(remote_address) = self .queue .pop_back() - .and_then(|socket| socket.peer_addr().ok()) + .map(|(_socket, remote_address)| remote_address) { - let remote_address: SocketAddress = remote_addr.into(); - let remote_address = &remote_address; publisher.on_acceptor_tcp_stream_dropped( - event::builder::AcceptorTcpStreamDropped { remote_address, reason: event::builder::AcceptorTcpStreamDropReason::FreshQueueAtCapacity }, + event::builder::AcceptorTcpStreamDropped { remote_address: &remote_address, reason: event::builder::AcceptorTcpStreamDropReason::FreshQueueAtCapacity }, ); dropped += 1; } } - let remote_address: SocketAddress = remote_addr.into(); - let remote_address = &remote_address; + let remote_address: SocketAddress = remote_address.into(); publisher.on_acceptor_tcp_fresh_enqueued( - event::builder::AcceptorTcpFreshEnqueued { remote_address }, + event::builder::AcceptorTcpFreshEnqueued { + remote_address: &remote_address, + }, ); enqueued += 1; // most recent streams go to the front of the line, since they're the most // likely to be successfully processed - self.queue.push_front(socket); + self.queue.push_front((socket, remote_address)); } Err(error) => { // TODO submit to a separate error channel that the application can subscribe @@ -239,7 +238,7 @@ impl FreshQueue { ) } - fn drain(&mut self) -> impl Iterator + '_ { + fn drain(&mut self) -> impl Iterator + '_ { self.queue.drain(..) } } @@ -292,6 +291,7 @@ where pub fn push( &mut self, stream: TcpStream, + remote_address: SocketAddress, now: Timestamp, subscriber: &Sub, publisher: &Pub, @@ -304,20 +304,14 @@ where // // TODO: we need to investigate how this interacts with SYN cookies/retries and fast // failure modes in kernel space. - if let Ok(remote_addr) = stream.peer_addr() { - let remote_address: SocketAddress = remote_addr.into(); - let remote_address = &remote_address; - publisher.on_acceptor_tcp_stream_dropped( - event::builder::AcceptorTcpStreamDropped { - remote_address, - reason: event::builder::AcceptorTcpStreamDropReason::SlotsAtCapacity, - }, - ); - } + publisher.on_acceptor_tcp_stream_dropped(event::builder::AcceptorTcpStreamDropped { + remote_address: &remote_address, + reason: event::builder::AcceptorTcpStreamDropReason::SlotsAtCapacity, + }); drop(stream); return; }; - self.workers[idx].push(stream, now, subscriber, publisher); + self.workers[idx].push(stream, remote_address, now, subscriber, publisher); self.working.push_back(idx); } @@ -473,7 +467,7 @@ where Sub: event::Subscriber + Clone, { queue_time: Timestamp, - stream: Option, + stream: Option<(TcpStream, SocketAddress)>, subscriber_ctx: Option, state: WorkerState, } @@ -495,6 +489,7 @@ where pub fn push( &mut self, stream: TcpStream, + remote_address: SocketAddress, now: Timestamp, subscriber: &Sub, publisher: &Pub, @@ -514,12 +509,10 @@ where let prev_queue_time = core::mem::replace(&mut self.queue_time, now); let prev_state = core::mem::replace(&mut self.state, WorkerState::Init); - let prev_stream = core::mem::replace(&mut self.stream, Some(stream)); + let prev_stream = core::mem::replace(&mut self.stream, Some((stream, remote_address))); let prev_ctx = core::mem::replace(&mut self.subscriber_ctx, Some(subscriber_ctx)); - if let Some(remote_addr) = prev_stream.and_then(|socket| socket.peer_addr().ok()) { - let remote_address: SocketAddress = remote_addr.into(); - let remote_address = &remote_address; + if let Some(remote_address) = prev_stream.map(|(_socket, remote_address)| remote_address) { let sojourn_time = now.saturating_duration_since(prev_queue_time); let buffer_len = match prev_state { WorkerState::Init => 0, @@ -527,7 +520,7 @@ where WorkerState::Erroring { .. } => 0, }; publisher.on_acceptor_tcp_stream_replaced(event::builder::AcceptorTcpStreamReplaced { - remote_address, + remote_address: &remote_address, sojourn_time, buffer_len, }); @@ -615,7 +608,7 @@ impl WorkerState { &mut self, cx: &mut Context, context: &mut WorkerContext, - stream: &mut Option, + stream: &mut Option<(TcpStream, SocketAddress)>, subscriber_ctx: &mut Option, queue_time: Timestamp, now: Timestamp, @@ -639,8 +632,8 @@ impl WorkerState { } => (buffer, *blocked_count), // we encountered an error so try and send it back WorkerState::Erroring { offset, buffer, .. } => { - let stream = Pin::new(stream.as_mut().unwrap()); - let len = ready!(stream.poll_write(cx, &buffer[*offset..]))?; + let (stream, _remote_address) = stream.as_mut().unwrap(); + let len = ready!(Pin::new(stream).poll_write(cx, &buffer[*offset..]))?; *offset += len; @@ -660,13 +653,17 @@ impl WorkerState { }; // try to read an initial packet from the socket - let res = Self::poll_initial_packet( - cx, - stream.as_mut().unwrap(), - recv_buffer, - sojourn_time, - publisher, - ); + let res = { + let (stream, remote_address) = stream.as_mut().unwrap(); + Self::poll_initial_packet( + cx, + stream, + remote_address, + recv_buffer, + sojourn_time, + publisher, + ) + }; let Poll::Ready(res) = res else { // if we got `Pending` but we don't own the recv buffer then we need to copy it @@ -689,11 +686,12 @@ impl WorkerState { let initial_packet = res?; let subscriber_ctx = subscriber_ctx.take().unwrap(); + let (socket, remote_address) = stream.take().unwrap(); let stream_builder = match endpoint::accept_stream( now, &context.env, - env::TcpReregistered(stream.take().unwrap()), + env::TcpReregistered(socket, remote_address), &initial_packet, None, Some(recv_buffer), @@ -704,11 +702,11 @@ impl WorkerState { ) { Ok(stream) => stream, Err(error) => { - if let Some(env::TcpReregistered(socket)) = error.peer { + if let Some(env::TcpReregistered(socket, remote_address)) = error.peer { if !error.secret_control.is_empty() { // if we need to send an error then update the state and loop back // around - *stream = Some(socket); + *stream = Some((socket, remote_address)); *self = WorkerState::Erroring { offset: 0, buffer: error.secret_control, @@ -768,6 +766,7 @@ impl WorkerState { fn poll_initial_packet( cx: &mut Context, stream: &mut TcpStream, + remote_address: &SocketAddress, recv_buffer: &mut msg::recv::Message, sojourn_time: Duration, publisher: &Pub, @@ -777,15 +776,9 @@ impl WorkerState { { loop { if recv_buffer.payload_len() > 10_000 { - let remote_address = stream - .peer_addr() - .ok() - .map(SocketAddress::from) - .unwrap_or_default(); - publisher.on_acceptor_tcp_packet_dropped( event::builder::AcceptorTcpPacketDropped { - remote_address: &remote_address, + remote_address, reason: DecoderError::UnexpectedBytes(recv_buffer.payload_len()) .into_event(), sojourn_time, @@ -798,15 +791,9 @@ impl WorkerState { match server::InitialPacket::peek(recv_buffer, 16) { Ok(packet) => { - let remote_address = stream - .peer_addr() - .ok() - .map(SocketAddress::from) - .unwrap_or_default(); - publisher.on_acceptor_tcp_packet_received( event::builder::AcceptorTcpPacketReceived { - remote_address: &remote_address, + remote_address, credential_id: &*packet.credentials.id, stream_id: packet.stream_id.into_varint().as_u64(), payload_len: packet.payload_len, @@ -823,15 +810,9 @@ impl WorkerState { continue; } - let remote_address = stream - .peer_addr() - .ok() - .map(SocketAddress::from) - .unwrap_or_default(); - publisher.on_acceptor_tcp_packet_dropped( event::builder::AcceptorTcpPacketDropped { - remote_address: &remote_address, + remote_address, reason: err.into_event(), sojourn_time, },