Skip to content

Commit

Permalink
wip: rewrite early-data test using rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
paolobarbolini committed Mar 10, 2024
1 parent c547c63 commit 2f08191
Showing 1 changed file with 63 additions and 107 deletions.
170 changes: 63 additions & 107 deletions tests/early-data.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
#![cfg(feature = "early-data")]

use std::io::{self, BufRead, BufReader, Cursor};
use std::net::SocketAddr;
use std::io::{self, BufReader, Cursor, Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;
use std::time::Duration;

use futures_util::{future, future::Future, ready};
use rustls::{self, ClientConfig, RootCertStore};
use tokio::io::{split, AsyncRead, AsyncWriteExt, ReadBuf};
use futures_util::{future::Future, ready};
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::time::sleep;
use tokio_rustls::{client::TlsStream, TlsConnector};

struct Read1<T>(T);
Expand Down Expand Up @@ -42,99 +38,76 @@ async fn send(
addr: SocketAddr,
data: &[u8],
vectored: bool,
) -> io::Result<TlsStream<TcpStream>> {
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
let connector = TlsConnector::from(config).early_data(true);
let stream = TcpStream::connect(&addr).await?;
let domain = pki_types::ServerName::try_from("foobar.com").unwrap();

let stream = connector.connect(domain, stream).await?;
let (mut rd, mut wd) = split(stream);
let (notify, wait) = oneshot::channel();

let j = tokio::spawn(async move {
// read to eof
//
// see https://www.mail-archive.com/openssl-users@openssl.org/msg84451.html
let mut read_task = Read1(&mut rd);
let mut notify = Some(notify);

// read once, then write
//
// this is a regression test, see https://github.com/tokio-rs/tls/issues/54
future::poll_fn(|cx| {
let ret = Pin::new(&mut read_task).poll(cx)?;
assert_eq!(ret, Poll::Pending);

notify.take().unwrap().send(()).unwrap();

Poll::Ready(Ok(())) as Poll<io::Result<_>>
})
.await?;

match read_task.await {
Ok(()) => (),
Err(ref err) if err.kind() == io::ErrorKind::UnexpectedEof => (),
Err(err) => return Err(err),
}

Ok(rd) as io::Result<_>
});

wait.await.unwrap();

utils::write(&mut wd, data, vectored).await?;
wd.flush().await?;
wd.shutdown().await?;
let mut stream = connector.connect(domain, stream).await?;
utils::write(&mut stream, data, vectored).await?;
stream.flush().await?;
stream.shutdown().await?;

let rd: tokio::io::ReadHalf<_> = j.await??;

Ok(rd.unsplit(wd))
}

struct DropKill(Child);

impl Drop for DropKill {
fn drop(&mut self) {
self.0.kill().unwrap();
}
}
let mut buf = Vec::new();
stream.read_to_end(&mut buf).await?;

async fn wait_for_server(addr: &str) {
let tries = 10;
for i in 0..tries {
if let Ok(_) = TcpStream::connect(addr).await {
return;
}
sleep(Duration::from_millis(i * 100)).await;
}
panic!("failed to connect to {:?} after {} tries", addr, tries)
Ok((stream, buf))
}

#[tokio::test]
async fn test_0rtt() -> io::Result<()> {
test_0rtt_impl(12354, false).await
test_0rtt_impl(false).await
}

#[tokio::test]
async fn test_0rtt_vectored() -> io::Result<()> {
test_0rtt_impl(12353, true).await
test_0rtt_impl(true).await
}

async fn test_0rtt_impl(server_port: u16, vectored: bool) -> io::Result<()> {
let mut handle = Command::new("openssl")
.arg("s_server")
.arg("-early_data")
.arg("-tls1_3")
.args(["-cert", "./tests/end.cert"])
.args(["-key", "./tests/end.rsa"])
.args(["-port", &server_port.to_string()])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.map(DropKill)?;

// wait openssl server
wait_for_server(format!("127.0.0.1:{}", server_port).as_str()).await;
async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
let cert_chain = rustls_pemfile::certs(&mut Cursor::new(include_bytes!("end.cert")))
.collect::<io::Result<Vec<_>>>()?;
let key_der =
rustls_pemfile::private_key(&mut Cursor::new(include_bytes!("end.rsa")))?.unwrap();
let mut server = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key_der)
.unwrap();
server.max_early_data_size = 8192;
let server = Arc::new(server);

let listener = TcpListener::bind("127.0.0.1:0")?;
let server_port = listener.local_addr().unwrap().port();
thread::spawn(move || loop {
let (mut sock, _addr) = listener.accept().unwrap();

let server = Arc::clone(&server);
thread::spawn(move || {
let mut conn = ServerConnection::new(server).unwrap();
conn.complete_io(&mut sock).unwrap();

if let Some(mut early_data) = conn.early_data() {
let mut buf = Vec::new();
early_data.read_to_end(&mut buf).unwrap();
let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"EARLY:").unwrap();
stream.write_all(&buf).unwrap();
}

let mut stream = Stream::new(&mut conn, &mut sock);
stream.write_all(b"LATE:").unwrap();
loop {
let mut buf = [0; 1024];
let n = stream.read(&mut buf).unwrap();
if n == 0 {
conn.send_close_notify();
conn.complete_io(&mut sock).unwrap();
break;
}
stream.write_all(&buf[..n]).unwrap();
}
});
});

let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
let mut root_store = RootCertStore::empty();
Expand All @@ -150,30 +123,13 @@ async fn test_0rtt_impl(server_port: u16, vectored: bool) -> io::Result<()> {
let config = Arc::new(config);
let addr = SocketAddr::from(([127, 0, 0, 1], server_port));

// workaround: write to openssl s_server standard input periodically, to
// get it unstuck on Windows
let stdin = handle.0.stdin.take().unwrap();
thread::spawn(move || {
let mut stdin = stdin;
loop {
thread::sleep(std::time::Duration::from_secs(5));
std::io::Write::write_all(&mut stdin, b"\n").unwrap();
}
});

let io = send(config.clone(), addr, b"hello", vectored).await?;
let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?;
assert!(!io.get_ref().1.is_early_data_accepted());
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));

let io = send(config, addr, b"world!", vectored).await?;
let (io, buf) = send(config, addr, b"world!", vectored).await?;
assert!(io.get_ref().1.is_early_data_accepted());

let stdout = handle.0.stdout.as_mut().unwrap();
let mut lines = BufReader::new(stdout).lines();

let has_msg1 = lines.by_ref().any(|line| line.unwrap().contains("hello"));
let has_msg2 = lines.by_ref().any(|line| line.unwrap().contains("world!"));

assert!(has_msg1 && has_msg2);
assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&buf));

Ok(())
}
Expand Down

0 comments on commit 2f08191

Please sign in to comment.