Skip to content

Commit

Permalink
fix: don't lock up if handshake fails in acceptor (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac authored May 6, 2024
1 parent a6f25ba commit eefbcad
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 45 deletions.
6 changes: 4 additions & 2 deletions examples/ssl_trace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
use rustls::client::danger::ServerCertVerified;
use rustls::client::danger::ServerCertVerifier;
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection};
use rustls::ClientConfig;
use rustls::ClientConnection;
use rustls_tokio_stream::TlsStream;
use std::env;
use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ mod stream;
#[cfg(test)]
mod system_test;

pub use stream::ServerConfigProvider;
pub use stream::TlsHandshake;
pub use stream::TlsStream;
pub use stream::TlsStreamRead;
pub use stream::TlsStreamWrite;
pub use stream::ServerConfigProvider;

/// Re-export the version of rustls we are built on
pub use rustls;
Expand Down
118 changes: 76 additions & 42 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ impl TlsStream {
let handshake_send = handshake.clone();
let handle = spawn(async move {
let res =
send_handshake(tcp_handshake, tls, test_options, handshake_send).await;
send_handshake(tcp_handshake, Ok(tls), test_options, handshake_send)
.await;

// We may have read/writes blocked on the handshake, so wake them all up
read_waker_clone.wake();
Expand All @@ -172,6 +173,26 @@ impl TlsStream {
}
}

async fn accept(
tcp_handshake: &TcpStream,
server_config_provider: ServerConfigProvider,
) -> Result<ServerConnection, io::Error> {
let mut acceptor = Acceptor::default();
let tls = loop {
tcp_handshake.readable().await?;
read_acceptor(&tcp_handshake, &mut acceptor)?;
if let Some(accepted) = acceptor.accept().map_err(rustls_to_io_error)? {
let f = server_config_provider(accepted.client_hello());
let config = f.await?;
let tls = accepted
.into_connection(config)
.map_err(rustls_to_io_error)?;
break tls;
}
};
Ok(tls)
}

fn new_server_acceptor(
tcp: TcpStream,
server_config_provider: ServerConfigProvider,
Expand All @@ -187,24 +208,12 @@ impl TlsStream {
let tcp_handshake = tcp.clone();

let handshake_send = handshake.clone();
let handle = spawn(async move {
let mut acceptor = Acceptor::default();
let tls = loop {
tcp_handshake.readable().await?;
read_acceptor(&tcp_handshake, &mut acceptor)?;
if let Some(accepted) = acceptor.accept().map_err(rustls_to_io_error)? {
let f = server_config_provider(accepted.client_hello());
let config = f.await?;
let tls = accepted
.into_connection(config)
.map_err(rustls_to_io_error)?;
break tls;
}
};

let handle = spawn(async move {
let tls = Self::accept(&tcp_handshake, server_config_provider).await;
let res = send_handshake(
tcp_handshake,
rustls::Connection::Server(tls),
tls.map(rustls::Connection::Server),
test_options,
handshake_send,
)
Expand Down Expand Up @@ -592,10 +601,20 @@ impl TlsStream {

async fn send_handshake(
tcp: Arc<TcpStream>,
tls: Connection,
tls: Result<Connection, io::Error>,
test_options: TestOptions,
handshake: Arc<HandshakeWatch>,
) -> Result<HandshakeResult, io::Error> {
let tls = match tls {
Ok(tls) => tls,
Err(err) => {
*handshake.handshake.lock().unwrap() = Some(Err(clone_error(&err)));
handshake.rx_waker.wake();
handshake.tx_waker.wake();
return Err(err);
}
};

#[cfg(test)]
if test_options.delay_handshake {
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
Expand Down Expand Up @@ -1257,13 +1276,20 @@ pub(super) mod tests {
}

async fn make_config(
alpn: &'static [&'static str],
alpn: Result<&'static [&'static str], &'static str>,
) -> Result<Arc<ServerConfig>, io::Error> {
Ok(server_config(alpn).into())
Ok(
server_config(
alpn.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?,
)
.into(),
)
}

async fn tls_pair_alpn_acceptor(
server_alpn: fn(ClientHello) -> &'static [&'static str],
server_alpn: fn(
ClientHello,
) -> Result<&'static [&'static str], &'static str>,
server_buffer_size: Option<NonZeroUsize>,
client_alpn: &[&str],
client_buffer_size: Option<NonZeroUsize>,
Expand Down Expand Up @@ -1356,7 +1382,7 @@ pub(super) mod tests {

/// Test that a flush before a handshake completes works.
#[tokio::test]
// #[ntest::timeout(60000)]
#[ntest::timeout(60000)]
async fn test_flush_before_handshake() -> TestResult {
let (mut server, mut client) = tls_pair().await;
server.write_all(b"hello?").await.unwrap();
Expand Down Expand Up @@ -1394,35 +1420,39 @@ pub(super) mod tests {
Ok(())
}

fn alpn_handler(
client_hello: ClientHello,
) -> Result<&'static [&'static str], &'static str> {
if let Some(alpn) = client_hello.alpn() {
for alpn in alpn {
if alpn == b"a" {
return Ok(&["a"]);
}
if alpn == b"b" {
return Ok(&["b"]);
}
}
}
Err("bad server")
}

/// Test that the handshake works, and we get the correct ALPN negotiated values.
#[rstest]
#[case("a")]
#[case("b")]
#[case("c")]
#[tokio::test]
// #[ntest::timeout(60000)]
#[ntest::timeout(60000)]
async fn test_client_server_alpn_acceptor(
#[case] alpn: &'static str,
) -> TestResult {
let (mut server, mut client) = tls_pair_alpn_acceptor(
|client_hello| {
if let Some(alpn) = client_hello.alpn() {
for alpn in alpn {
if alpn == b"a" {
return &["a"];
}
if alpn == b"b" {
return &["b"];
}
}
}
&[]
},
None,
&[alpn],
None,
)
.await;
let (mut server, mut client) =
tls_pair_alpn_acceptor(alpn_handler, None, &[alpn], None).await;
let a = spawn(async move {
if alpn == "c" {
server.handshake().await.expect_err("expected failure");
return;
}
let handshake = server.handshake().await.unwrap();
assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
assert_eq!(handshake.sni, Some("example.com".into()));
Expand All @@ -1432,6 +1462,10 @@ pub(super) mod tests {
assert_eq!(buf.as_slice(), b"hello!");
});
let b = spawn(async move {
if alpn == "c" {
client.handshake().await.expect_err("expected failure");
return;
}
let handshake = client.handshake().await.unwrap();
assert_eq!(handshake.alpn, Some(alpn.as_bytes().to_vec()));
client.write_all(b"hello!").await.unwrap();
Expand All @@ -1446,7 +1480,7 @@ pub(super) mod tests {

/// Test that the handshake fails, and we get the correct errors on both ends.
#[tokio::test]
// #[ntest::timeout(60000)]
#[ntest::timeout(60000)]
async fn test_client_server_alpn_mismatch() -> TestResult {
let (mut server, mut client) =
tls_pair_alpn(&["a"], None, &["b"], None).await;
Expand Down

0 comments on commit eefbcad

Please sign in to comment.