diff --git a/quic/s2n-quic/src/tests/deduplicate.rs b/quic/s2n-quic/src/tests/deduplicate.rs index ddd25e093..a1ceba878 100644 --- a/quic/s2n-quic/src/tests/deduplicate.rs +++ b/quic/s2n-quic/src/tests/deduplicate.rs @@ -1,6 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +use crate::provider::endpoint_limits::Outcome; + use super::*; use s2n_quic_core::{dc::testing::MockDcEndpoint, stateless_reset::token::testing::TEST_TOKEN_1}; @@ -14,7 +16,7 @@ fn deduplicate_successfully() { let server_subscriber = recorder::ConnectionStarted::new(); let server_events = server_subscriber.events(); let client_subscriber = recorder::ConnectionStarted::new(); - let client_events = server_subscriber.events(); + let client_events = client_subscriber.events(); test(model, |handle| { let mut server = Server::builder() .with_io(handle.builder().build()?)? @@ -127,3 +129,93 @@ fn confirm_conn_works( assert_eq!(LEN, recv_len); } } + +#[derive(Clone)] +struct Toggle(Arc>); + +impl Toggle { + fn new(outcome: Outcome) -> Self { + Self(Arc::new(Mutex::new(outcome))) + } + + fn set(&self, outcome: Outcome) { + *self.0.lock().unwrap() = outcome; + } +} + +impl crate::provider::endpoint_limits::Limiter for Toggle { + fn on_connection_attempt( + &mut self, + _info: &crate::provider::endpoint_limits::ConnectionAttempt<'_>, + ) -> Outcome { + self.0.lock().unwrap().clone() + } +} + +#[test] +fn deduplicate_non_terminal() { + let model = Model::default(); + model.set_delay(Duration::from_millis(50)); + + let server_subscriber = recorder::ConnectionStarted::new(); + let server_events = server_subscriber.events(); + let client_subscriber = recorder::ConnectionStarted::new(); + let client_events = client_subscriber.events(); + test(model, |handle| { + let toggle = Toggle::new(Outcome::drop()); + let tokens = [TEST_TOKEN_1]; + let mut server = Server::builder() + .with_io(handle.builder().build()?)? + .with_tls(SERVER_CERTS)? + .with_event((tracing_events(), server_subscriber.clone()))? + .with_random(Random::with_seed(456))? + .with_dc(MockDcEndpoint::new(&tokens))? + .with_endpoint_limits(toggle.clone())? + .start()?; + + let addr = server.local_addr()?; + spawn(async move { + let mut conn = server.accept().await.unwrap(); + for _ in 0..2 { + let mut stream = conn.open_bidirectional_stream().await.unwrap(); + stream.send(vec![42; LEN].into()).await.unwrap(); + stream.flush().await.unwrap(); + } + let mut conn = server.accept().await.unwrap(); + let mut stream = conn.open_bidirectional_stream().await.unwrap(); + stream.send(vec![42; LEN].into()).await.unwrap(); + stream.flush().await.unwrap(); + }); + + let tokens = [TEST_TOKEN_1]; + let client = Client::builder() + .with_io(handle.builder().build().unwrap())? + .with_tls(certificates::CERT_PEM)? + .with_event((tracing_events(), client_subscriber))? + .with_random(Random::with_seed(456))? + .with_dc(MockDcEndpoint::new(&tokens))? + .start()?; + + primary::spawn(async move { + let connect = Connect::new(addr) + .with_server_name("localhost") + .with_deduplicate(true); + client.connect(connect.clone()).await.unwrap_err(); + + // now allow connections + toggle.set(Outcome::allow()); + + let mut conn = client.connect(connect.clone()).await.unwrap(); + confirm_conn_works(&mut conn).await; + }); + + Ok(addr) + }) + .unwrap(); + + let server_started_count = server_events.lock().unwrap().len(); + let client_started_count = client_events.lock().unwrap().len(); + + assert_eq!(server_started_count, 1); + assert_eq!(client_started_count, 2); +} diff --git a/quic/s2n-quic/src/tests/recorder.rs b/quic/s2n-quic/src/tests/recorder.rs index 526bc0b8c..a8bc332fc 100644 --- a/quic/s2n-quic/src/tests/recorder.rs +++ b/quic/s2n-quic/src/tests/recorder.rs @@ -162,8 +162,6 @@ event_recorder!( SocketAddr, |event: &events::ConnectionStarted, storage: &mut Vec| { let addr: SocketAddr = event.path.local_addr.to_string().parse().unwrap(); - if storage.last().map_or(true, |prev| *prev != addr) { - storage.push(addr); - } + storage.push(addr); } );