From 61cfc7a84cca2c7153d4900711130d1935d7fd4b Mon Sep 17 00:00:00 2001 From: Wesley Rosenblum <55108558+WesleyRosenblum@users.noreply.github.com> Date: Mon, 22 Jul 2024 17:01:10 -0700 Subject: [PATCH] fix(s2n-quic-dc): handle possible secret control packet correctly (#2280) --- quic/s2n-quic-core/src/dc/testing.rs | 12 +- quic/s2n-quic-transport/src/endpoint/mod.rs | 23 ++- quic/s2n-quic/src/tests/dc.rs | 194 +++++++++++++++++--- 3 files changed, 191 insertions(+), 38 deletions(-) diff --git a/quic/s2n-quic-core/src/dc/testing.rs b/quic/s2n-quic-core/src/dc/testing.rs index e1ccd1abde..2fea792b56 100644 --- a/quic/s2n-quic-core/src/dc/testing.rs +++ b/quic/s2n-quic-core/src/dc/testing.rs @@ -9,15 +9,23 @@ use crate::{ varint::VarInt, }; use core::time::Duration; +use std::sync::{ + atomic::{AtomicU8, Ordering}, + Arc, +}; pub struct MockDcEndpoint { stateless_reset_tokens: Vec, + pub on_possible_secret_control_packet_count: Arc, + pub on_possible_secret_control_packet: fn() -> bool, } impl MockDcEndpoint { pub fn new(tokens: &[stateless_reset::Token]) -> Self { Self { stateless_reset_tokens: tokens.to_vec(), + on_possible_secret_control_packet_count: Arc::new(AtomicU8::default()), + on_possible_secret_control_packet: || false, } } } @@ -45,7 +53,9 @@ impl dc::Endpoint for MockDcEndpoint { _datagram_info: &DatagramInfo, _payload: &mut [u8], ) -> bool { - false + self.on_possible_secret_control_packet_count + .fetch_add(1, Ordering::Relaxed); + (self.on_possible_secret_control_packet)() } } diff --git a/quic/s2n-quic-transport/src/endpoint/mod.rs b/quic/s2n-quic-transport/src/endpoint/mod.rs index c0dba82f1f..2219e13e5e 100644 --- a/quic/s2n-quic-transport/src/endpoint/mod.rs +++ b/quic/s2n-quic-transport/src/endpoint/mod.rs @@ -458,14 +458,6 @@ impl Endpoint { endpoint_context.connection_id_format, ) { (packet, remaining) - } else if Cfg::DcEndpoint::ENABLED - && endpoint_context - .dc - .on_possible_secret_control_packet(&dc::DatagramInfo::new(&remote_address), payload) - { - // This was a DC secret control packet, so we don't need to proceed - // with checking for a stateless reset - return; } else { //= https://www.rfc-editor.org/rfc/rfc9000#section-5.2.2 //# Servers MUST drop incoming packets under all other circumstances. @@ -761,12 +753,25 @@ impl Endpoint { } } (_, packet) => { + let is_short_header_packet = matches!(packet, ProtectedPacket::Short(_)); + + if Cfg::DcEndpoint::ENABLED + && is_short_header_packet // dc packets are short header packets + && endpoint_context.dc.on_possible_secret_control_packet( + &dc::DatagramInfo::new(&remote_address), + payload, + ) + { + // This was a DC secret control packet, so we don't need to proceed + // with checking for a stateless reset + return; + } + publisher.on_endpoint_datagram_dropped(event::builder::EndpointDatagramDropped { len: payload_len as u16, reason: event::builder::DatagramDropReason::UnknownDestinationConnectionId, }); - let is_short_header_packet = matches!(packet, ProtectedPacket::Short(_)); //= https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1 //# Endpoints MAY skip this check if any packet from a datagram is //# successfully processed. However, the comparison MUST be performed diff --git a/quic/s2n-quic/src/tests/dc.rs b/quic/s2n-quic/src/tests/dc.rs index 8d7225c9e3..a49846030c 100644 --- a/quic/s2n-quic/src/tests/dc.rs +++ b/quic/s2n-quic/src/tests/dc.rs @@ -10,16 +10,22 @@ use crate::{ server, server::ServerProviders, }; +use s2n_codec::DecoderBufferMut; use s2n_quic_core::{ crypto::tls, dc::testing::MockDcEndpoint, - event::{api::DcState, Timestamp}, + event::{ + api::{DatagramDropReason, DcState, EndpointDatagramDropped, EndpointMeta, Subject}, + Timestamp, + }, frame::ConnectionClose, + packet::interceptor::{Datagram, Interceptor}, stateless_reset, stateless_reset::token::testing::{TEST_TOKEN_1, TEST_TOKEN_2}, transport, varint::VarInt, }; +use std::sync::atomic::Ordering; const SERVER_TOKENS: [stateless_reset::Token; 1] = [TEST_TOKEN_1]; const CLIENT_TOKENS: [stateless_reset::Token; 1] = [TEST_TOKEN_2]; @@ -63,7 +69,9 @@ fn dc_handshake_self_test() -> Result<()> { .with_tls(certificates::CERT_PEM)? .with_dc(MockDcEndpoint::new(&CLIENT_TOKENS))?; - self_test(server, client, None, None) + self_test(server, client, None, None)?; + + Ok(()) } // Client Server @@ -106,7 +114,9 @@ fn dc_mtls_handshake_self_test() -> Result<()> { .with_tls(client_tls)? .with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?; - self_test(server, client, None, None) + self_test(server, client, None, None)?; + + Ok(()) } #[test] @@ -133,7 +143,9 @@ fn dc_mtls_handshake_auth_failure_self_test() -> Result<()> { } .into(); - self_test(server, client, Some(expected_client_error), None) + self_test(server, client, Some(expected_client_error), None)?; + + Ok(()) } // Client Server @@ -173,7 +185,9 @@ fn dc_mtls_handshake_server_not_supported_self_test() -> Result<()> { "peer does not support specified dc versions", )), Some(expected_server_error), - ) + )?; + + Ok(()) } // Client Server @@ -218,7 +232,66 @@ fn dc_mtls_handshake_client_not_supported_self_test() -> Result<()> { Some(connection::Error::invalid_configuration( "peer does not support specified dc versions", )), - ) + )?; + + Ok(()) +} + +#[test] +fn dc_secret_control_packet() -> Result<()> { + dc_possible_secret_control_packet(|| true) +} + +#[test] +fn dc_not_secret_control_packet() -> Result<()> { + dc_possible_secret_control_packet(|| false) +} + +fn dc_possible_secret_control_packet( + on_possible_secret_control_packet: fn() -> bool, +) -> Result<()> { + let server_tls = build_server_mtls_provider(certificates::MTLS_CA_CERT)?; + let server = Server::builder() + .with_tls(server_tls)? + .with_dc(MockDcEndpoint::new(&SERVER_TOKENS))?; + + let client_tls = build_client_mtls_provider(certificates::MTLS_CA_CERT)?; + let mut dc_endpoint = MockDcEndpoint::new(&CLIENT_TOKENS); + dc_endpoint.on_possible_secret_control_packet = on_possible_secret_control_packet; + let on_possible_secret_control_packet_count = + dc_endpoint.on_possible_secret_control_packet_count.clone(); + + let client = Client::builder() + .with_tls(client_tls)? + .with_dc(dc_endpoint)? + .with_packet_interceptor(RandomShort::default())?; + + let (client_events, _server_events) = self_test(server, client, None, None)?; + + assert_eq!( + 1, + on_possible_secret_control_packet_count.load(Ordering::Relaxed) + ); + + let client_datagram_drops = client_events + .endpoint_datagram_dropped_events + .lock() + .unwrap(); + + if on_possible_secret_control_packet() { + // No datagrams should be recorded as dropped because MockDcEndpoint::on_possible_secret_control_packet + // returned true, indicating the given datagram was a secret control packet + assert_eq!(0, client_datagram_drops.len()); + } else { + // The datagram was not a secret control packet, so it is dropped + assert_eq!(1, client_datagram_drops.len()); + assert!(matches!( + client_datagram_drops[0].reason, + DatagramDropReason::UnknownDestinationConnectionId { .. } + )); + } + + Ok(()) } fn self_test( @@ -226,14 +299,14 @@ fn self_test( client: client::Builder, expected_client_error: Option, expected_server_error: Option, -) -> Result<()> { +) -> Result<(DcRecorder, DcRecorder)> { let model = Model::default(); let rtt = Duration::from_millis(100); model.set_delay(rtt / 2); - let server_subscriber = DcStateChanged::new(); + let server_subscriber = DcRecorder::new(); let server_events = server_subscriber.clone(); - let client_subscriber = DcStateChanged::new(); + let client_subscriber = DcRecorder::new(); let client_events = client_subscriber.clone(); test(model, |handle| { @@ -284,7 +357,11 @@ fn self_test( } } else { assert!(result.is_ok()); - let client_events = client_events.events().lock().unwrap().clone(); + let client_events = client_events + .dc_state_changed_events() + .lock() + .unwrap() + .clone(); assert_dc_complete(&client_events); // wait briefly so the ack for the `DC_STATELESS_RESET_TOKENS` frame from the server is sent // before the client closes the connection. This is only necessary to confirm the `dc::State` @@ -298,33 +375,55 @@ fn self_test( .unwrap(); if expected_client_error.is_some() || expected_server_error.is_some() { - return Ok(()); + return Ok((client_events, server_events)); } - let server_events = server_events.events().lock().unwrap().clone(); - let client_events = client_events.events().lock().unwrap().clone(); + let server_dc_state_changed_events = server_events + .dc_state_changed_events() + .lock() + .unwrap() + .clone(); + let client_dc_state_changed_events = client_events + .dc_state_changed_events() + .lock() + .unwrap() + .clone(); - assert_dc_complete(&server_events); - assert_dc_complete(&client_events); + assert_dc_complete(&server_dc_state_changed_events); + assert_dc_complete(&client_dc_state_changed_events); // Server path secrets are ready in 1.5 RTTs measured from the start of the test, since it takes // .5 RTT for the Initial from the client to reach the server assert_eq!( // remove floating point division error Duration::from_millis(rtt.mul_f32(1.5).as_millis() as u64), - server_events[1].timestamp.duration_since_start() + server_dc_state_changed_events[1] + .timestamp + .duration_since_start() + ); + assert_eq!( + rtt, + client_dc_state_changed_events[1] + .timestamp + .duration_since_start() ); - assert_eq!(rtt, client_events[1].timestamp.duration_since_start()); // Server completes in 2.5 RTTs measured from the start of the test, since it takes .5 RTT // for the Initial from the client to reach the server assert_eq!( rtt.mul_f32(2.5), - server_events[2].timestamp.duration_since_start() + server_dc_state_changed_events[2] + .timestamp + .duration_since_start() + ); + assert_eq!( + rtt * 2, + client_dc_state_changed_events[2] + .timestamp + .duration_since_start() ); - assert_eq!(rtt * 2, client_events[2].timestamp.duration_since_start()); - Ok(()) + Ok((client_events, server_events)) } fn assert_dc_complete(events: &[DcStateChangedEvent]) { @@ -358,21 +457,22 @@ struct DcStateChangedEvent { } #[derive(Clone, Default)] -struct DcStateChanged { - pub events: Arc>>, +struct DcRecorder { + pub dc_state_changed_events: Arc>>, + pub endpoint_datagram_dropped_events: Arc>>, } -impl DcStateChanged { +impl DcRecorder { pub fn new() -> Self { Self::default() } - pub fn events(&self) -> Arc>> { - self.events.clone() + pub fn dc_state_changed_events(&self) -> Arc>> { + self.dc_state_changed_events.clone() } } -impl events::Subscriber for DcStateChanged { - type ConnectionContext = DcStateChanged; +impl events::Subscriber for DcRecorder { + type ConnectionContext = DcRecorder; fn create_connection_context( &mut self, @@ -394,7 +494,45 @@ impl events::Subscriber for DcStateChanged { state: event.state.clone(), }); }; - let mut buffer = context.events.lock().unwrap(); + let mut buffer = context.dc_state_changed_events.lock().unwrap(); store(event, &mut buffer); } + + fn on_endpoint_datagram_dropped( + &mut self, + _meta: &EndpointMeta, + event: &EndpointDatagramDropped, + ) { + self.endpoint_datagram_dropped_events + .lock() + .unwrap() + .push(event.clone()); + } +} + +/// Replace the first short packet payload with a randomized short packet +#[derive(Default)] +struct RandomShort(bool); + +impl Interceptor for RandomShort { + #[inline] + fn intercept_rx_datagram<'a>( + &mut self, + _subject: &Subject, + _datagram: &Datagram, + payload: DecoderBufferMut<'a>, + ) -> DecoderBufferMut<'a> { + let payload = payload.into_less_safe_slice(); + + if let 0b0100u8..=0b0111u8 = payload[0] >> 4 { + if !self.0 { + // randomize everything after the short header tag + rand::fill_bytes(&mut payload[1..]); + // only change the first short packet + self.0 = true; + } + } + + DecoderBufferMut::new(payload) + } }