Skip to content

Commit

Permalink
fix: cancel accept if client disconnected (#28)
Browse files Browse the repository at this point in the history
* rebase

* Added test

* fix test
  • Loading branch information
wille-io authored Jul 12, 2024
1 parent 8237025 commit a1b7b64
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ impl TlsStream {
let mut acceptor = Acceptor::default();
loop {
tcp_handshake.readable().await?;
read_acceptor(tcp_handshake, &mut acceptor)?;
// Stop if connection was closed by client
if read_acceptor(&tcp_handshake, &mut acceptor)? < 1 {
return Err(io::ErrorKind::ConnectionReset.into());
}

let accepted = match acceptor.accept() {
Ok(Some(accepted)) => accepted,
Expand Down
96 changes: 96 additions & 0 deletions src/system_test/disconnect_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use crate::tests::certificate;
use crate::tests::private_key;
use crate::TlsStream;
use rustls::server::ClientHello;
use rustls::ServerConfig;
use std::io;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::Ipv4Addr;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpSocket;
use tokio::net::TcpListener;
use tokio::spawn;

fn alpn_handler(
client_hello: ClientHello,
) -> Result<&'static [&'static str], &'static str> {
if let Some(alpn) = client_hello.alpn() {
for alpn in alpn {
if alpn == b"a" {
return Ok(&["a"]);
}
if alpn == b"b" {
return Ok(&["b"]);
}
}
}
Err("bad server")
}

fn server_config_alpn(alpn: &[&str]) -> ServerConfig {
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![certificate()], private_key())
.expect("Failed to build server config");
config.alpn_protocols =
alpn.iter().map(|v| v.as_bytes().to_owned()).collect();
config
}

async fn make_config(
alpn: Result<&'static [&'static str], &'static str>,
) -> Result<Arc<ServerConfig>, io::Error> {
Ok(
server_config_alpn(
alpn.map_err(|e| io::Error::new(std::io::ErrorKind::InvalidData, e))?,
)
.into(),
)
}

#[tokio::test]
async fn disconnect_test() {
let listener = TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::LOCALHOST,
0,
)))
.await
.unwrap();

let port = listener.local_addr().unwrap().port();

let _client = spawn(async move {
TcpSocket::new_v4()
.unwrap()
.connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
.await
.unwrap()
});

let server = listener.accept().await.unwrap().0;
let mut client = _client.await.unwrap();

client.shutdown().await.expect("Shutdown failed"); // Disconnect before tls handshake

TlsStream::new_server_side_acceptor(
server,
Arc::new(move |client_hello| {
Box::pin(make_config(alpn_handler(client_hello)))
}),
None
);

// At this point, the acceptor is in an infinite loop, to test if it's really so, try to connect another client.

spawn(async move {
TcpSocket::new_v4()
.unwrap()
.connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
.await
.unwrap()
}).await.unwrap();

listener.accept().await.unwrap().0; // The test should be stuck now if the bug is still active
}
1 change: 1 addition & 0 deletions src/system_test/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
mod fastwebsockets;
mod speed_test;
mod disconnect_test;

0 comments on commit a1b7b64

Please sign in to comment.