Skip to content

Commit

Permalink
tests: convert more tests to utils::make_configs()
Browse files Browse the repository at this point in the history
There's still some improvements left to be made, but this reduces
a great deal of duplication in the test code.
  • Loading branch information
cpu committed Jul 13, 2024
1 parent c5726b7 commit 5690851
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 58 deletions.
31 changes: 7 additions & 24 deletions tests/early-data.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#![cfg(feature = "early-data")]

use std::io::{self, BufReader, Cursor, Read, Write};
use std::io::{self, Read, Write};
use std::net::{SocketAddr, TcpListener};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;

use futures_util::{future::Future, ready};
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream};
use rustls::{self, ClientConfig, ServerConnection, Stream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};
Expand Down Expand Up @@ -65,14 +65,7 @@ async fn test_0rtt_vectored() -> io::Result<()> {
}

async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
let cert_chain = rustls_pemfile::certs(&mut Cursor::new(include_bytes!("certs/end.cert")))
.collect::<io::Result<Vec<_>>>()?;
let key_der =
rustls_pemfile::private_key(&mut Cursor::new(include_bytes!("certs/end.rsa")))?.unwrap();
let mut server = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, key_der)
.unwrap();
let (mut server, mut client) = utils::make_configs();
server.max_early_data_size = 8192;
let server = Arc::new(server);

Expand Down Expand Up @@ -109,25 +102,15 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
});
});

let mut chain = BufReader::new(Cursor::new(include_str!("certs/end.chain")));
let mut root_store = RootCertStore::empty();
for cert in rustls_pemfile::certs(&mut chain) {
root_store.add(cert.unwrap()).unwrap();
}

let mut config =
rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_root_certificates(root_store)
.with_no_client_auth();
config.enable_early_data = true;
let config = Arc::new(config);
client.enable_early_data = true;
let client = Arc::new(client);
let addr = SocketAddr::from(([127, 0, 0, 1], server_port));

let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?;
let (io, buf) = send(client.clone(), addr, b"hello", vectored).await?;
assert!(!io.get_ref().1.is_early_data_accepted());
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));

let (io, buf) = send(config, addr, b"world!", vectored).await?;
let (io, buf) = send(client, addr, b"world!", vectored).await?;
assert!(io.get_ref().1.is_early_data_accepted());
assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&buf));

Expand Down
40 changes: 6 additions & 34 deletions tests/test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::{BufReader, Cursor, ErrorKind};
use std::io::{Cursor, ErrorKind};
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::Arc;
Expand All @@ -8,31 +8,17 @@ use std::{io, thread};
use futures_util::future::TryFutureExt;
use lazy_static::lazy_static;
use rustls::ClientConfig;
use rustls_pemfile::{certs, rsa_private_keys};
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::{runtime, time};
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};

const CERT: &str = include_str!("certs/end.cert");
const CHAIN: &[u8] = include_bytes!("certs/end.chain");
const RSA: &str = include_str!("certs/end.rsa");

lazy_static! {
static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = {
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.map(|result| result.unwrap())
.collect();
let key = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA)))
.next()
.unwrap()
.unwrap();

let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key.into())
.unwrap();
let (config, _) = utils::make_configs();
let acceptor = TlsAcceptor::from(Arc::new(config));

let (send, recv) = channel();
Expand Down Expand Up @@ -102,22 +88,15 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>)

#[tokio::test]
async fn pass() -> io::Result<()> {
let (addr, domain, chain) = start_server();
let (addr, domain, _) = start_server();

// TODO: not sure how to resolve this right now but since
// TcpStream::bind now returns a future it creates a race
// condition until its ready sometimes.
use std::time::*;
tokio::time::sleep(Duration::from_secs(1)).await;

let mut root_store = rustls::RootCertStore::empty();
for cert in certs(&mut std::io::Cursor::new(*chain)) {
root_store.add(cert.unwrap()).unwrap();
}

let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let (_, config) = utils::make_configs();
let config = Arc::new(config);

start_client(*addr, domain, config).await?;
Expand All @@ -127,16 +106,9 @@ async fn pass() -> io::Result<()> {

#[tokio::test]
async fn fail() -> io::Result<()> {
let (addr, domain, chain) = start_server();

let mut root_store = rustls::RootCertStore::empty();
for cert in certs(&mut std::io::Cursor::new(*chain)) {
root_store.add(cert.unwrap()).unwrap();
}
let (addr, domain, _) = start_server();

let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
let (_, config) = utils::make_configs();
let config = Arc::new(config);

assert_ne!(domain, &"google.com");
Expand Down

0 comments on commit 5690851

Please sign in to comment.