Skip to content

Commit

Permalink
feat: port to rustls 0.20
Browse files Browse the repository at this point in the history
* attempt to port to rustls 0.20

* clippy

* format

* Fix test certificates (expired) and add a script to regenerate them.

* Fix hanging and failing unit-tests

UnexpectedEof errors are bubbled up from rustls. Tests needed to changed slightly, but are en par with tokio/tls.

* Fix integration tests

* Fix client and server examples

* Update async-std version used for testing

---------

Co-authored-by: Jason Mobarak <jason@swiftnav.com>
  • Loading branch information
mfelsche and Jason Mobarak authored Feb 1, 2023
1 parent a8ca3ca commit 55f643c
Show file tree
Hide file tree
Showing 21 changed files with 442 additions and 318 deletions.
16 changes: 11 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
[package]
name = "async-tls"
version = "0.11.0"
authors = ["The async-rs developers", "Florian Gilcher <florian.gilcher@ferrous-systems.com>", "dignifiedquire <dignifiedquire@gmail.com>", "quininer kel <quininer@live.com>"]
authors = [
"The async-rs developers",
"Florian Gilcher <florian.gilcher@ferrous-systems.com>",
"dignifiedquire <dignifiedquire@gmail.com>",
"quininer kel <quininer@live.com>",
]
license = "MIT/Apache-2.0"
repository = "https://github.com/async-std/async-tls"
homepage = "https://github.com/async-std/async-tls"
Expand All @@ -18,9 +23,10 @@ appveyor = { repository = "async-std/async-tls" }
[dependencies]
futures-io = "0.3.5"
futures-core = "0.3.5"
rustls = "0.19.0"
webpki = { version = "0.21.3", optional = true }
webpki-roots = { version = "0.21.0", optional = true }
rustls = "0.20.6"
rustls-pemfile = "1.0"
webpki = { version = "0.22.0", optional = true }
webpki-roots = { version = "0.22.3", optional = true }

[features]
default = ["client", "server"]
Expand All @@ -32,7 +38,7 @@ server = []
lazy_static = "1"
futures-executor = "0.3.5"
futures-util = { version = "0.3.5", features = ["io"] }
async-std = { version = "1.0", features = ["unstable"] }
async-std = { version = "1.11", features = ["unstable"] }

[[test]]
name = "test"
Expand Down
5 changes: 3 additions & 2 deletions examples/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2018"

[dependencies]
structopt = "0.3.9"
rustls = "0.19.0"
async-std = "1.5.0"
rustls = "0.20.6"
rustls-pemfile = "1.0"
async-std = "1.11.0"
async-tls = { path = "../.." }
19 changes: 11 additions & 8 deletions examples/client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use async_std::task;
use async_tls::TlsConnector;

use rustls::ClientConfig;
use rustls_pemfile::certs;

use std::io::Cursor;
use std::io::{BufReader, Cursor};
use std::net::ToSocketAddrs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
Expand Down Expand Up @@ -81,12 +82,14 @@ fn main() -> io::Result<()> {
}

async fn connector_for_ca_file(cafile: &Path) -> io::Result<TlsConnector> {
let mut config = ClientConfig::new();
let file = async_std::fs::read(cafile).await?;
let mut pem = Cursor::new(file);
config
.root_store
.add_pem_file(&mut pem)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?;
let mut root_store = rustls::RootCertStore::empty();
let ca_bytes = async_std::fs::read(cafile).await?;
let cert = certs(&mut BufReader::new(Cursor::new(ca_bytes))).unwrap();
debug_assert_eq!((1, 0), root_store.add_parsable_certificates(&cert));

let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
9 changes: 5 additions & 4 deletions examples/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ authors = ["The async-rs developers", "quininer <quininer@live.com>"]
edition = "2018"

[dependencies]
structopt = "0.3.9"
async-std = "1.5.0"
async-std = "1.11.0"
async-tls = { path = "../.." }
rustls = "0.19.0"
webpki = "0.21.3"
futures-lite = "1.12.0"
rustls = "0.20.6"
rustls-pemfile = "1.0"
structopt = "0.3.9"
38 changes: 25 additions & 13 deletions examples/server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use async_std::io;
use async_std::net::{TcpListener, TcpStream};
use async_std::prelude::*;
use async_std::stream::StreamExt;
use async_std::task;
use async_tls::TlsAcceptor;
use rustls::internal::pemfile::{certs, rsa_private_keys};
use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
use futures_lite::io::AsyncWriteExt;
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, read_one, Item};

use std::fs::File;
use std::io::BufReader;
Expand All @@ -28,14 +29,23 @@ struct Options {

/// Load the passed certificates file
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
Ok(certs(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?
.into_iter()
.map(Certificate)
.collect())
}

/// Load the passed keys file
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
rsa_private_keys(&mut BufReader::new(File::open(path)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
fn load_key(path: &Path) -> io::Result<PrivateKey> {
match read_one(&mut BufReader::new(File::open(path)?)) {
Ok(Some(Item::RSAKey(data) | Item::PKCS8Key(data))) => Ok(PrivateKey(data)),
Ok(_) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid key in {}", path.display()),
)),
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
}
}

/// Configure the server using rusttls
Expand All @@ -44,13 +54,15 @@ fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
/// A TLS server needs a certificate and a fitting private key
fn load_config(options: &Options) -> io::Result<ServerConfig> {
let certs = load_certs(&options.cert)?;
let mut keys = load_keys(&options.key)?;
debug_assert_eq!(1, certs.len());
let key = load_key(&options.key)?;

// we don't use client authentication
let mut config = ServerConfig::new(NoClientAuth::new());
config
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
// set this server to use one cert together with the loaded private key
.set_single_cert(certs, keys.remove(0))
.with_single_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

Ok(config)
Expand Down Expand Up @@ -78,7 +90,7 @@ async fn handle_connection(acceptor: &TlsAcceptor, tcp_stream: &mut TcpStream) -
)
.await?;

tls_stream.flush().await?;
tls_stream.close().await?;

Ok(())
}
Expand Down
18 changes: 12 additions & 6 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::common::tls_state::TlsState;
use crate::server;

use futures_io::{AsyncRead, AsyncWrite};
use rustls::{ServerConfig, ServerSession};
use rustls::{ServerConfig, ServerConnection};
use std::future::Future;
use std::io;
use std::pin::Pin;
Expand Down Expand Up @@ -39,17 +39,23 @@ impl TlsAcceptor {
self.accept_with(stream, |_| ())
}

// Currently private, as exposing ServerSessions exposes rusttls
// Currently private, as exposing ServerConnections exposes rusttls
fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ServerSession),
F: FnOnce(&mut ServerConnection),
{
let mut session = ServerSession::new(&self.inner);
f(&mut session);
let mut conn = match ServerConnection::new(self.inner.clone()) {
Ok(conn) => conn,
Err(_) => {
return Accept(server::MidHandshake::End);
}
};

f(&mut conn);

Accept(server::MidHandshake::Handshaking(server::TlsStream {
session,
conn,
io: stream,
state: TlsState::Stream,
}))
Expand Down
19 changes: 10 additions & 9 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ use crate::common::tls_state::TlsState;
use crate::rusttls::stream::Stream;
use futures_core::ready;
use futures_io::{AsyncRead, AsyncWrite};
use rustls::ClientSession;
use rustls::ClientConnection;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};

use rustls::Session;

/// The client end of a TLS connection. Can be used like any other bidirectional IO stream.
/// Wraps the underlying TCP stream.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientSession,
pub(crate) session: ClientConnection,
pub(crate) state: TlsState,

#[cfg(feature = "early-data")]
Expand Down Expand Up @@ -58,11 +56,11 @@ where
let (io, session) = (&mut stream.io, &mut stream.session);
let mut stream = Stream::new(io, session).set_eof(eof);

if stream.session.is_handshaking() {
if stream.conn.is_handshaking() {
ready!(stream.complete_io(cx))?;
}

if stream.session.wants_write() {
if stream.conn.wants_write() {
ready!(stream.complete_io(cx))?;
}
}
Expand Down Expand Up @@ -90,17 +88,20 @@ where
TlsState::EarlyData => {
let this = self.get_mut();

let is_handshaking = this.session.is_handshaking();
let is_early_data_accepted = this.session.is_early_data_accepted();

let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
let (pos, data) = &mut this.early_data;

// complete handshake
if stream.session.is_handshaking() {
if is_handshaking {
ready!(stream.complete_io(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
if !is_early_data_accepted {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
Expand All @@ -127,7 +128,7 @@ where
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
if this.state.writeable() {
stream.session.send_close_notify();
stream.conn.send_close_notify();
this.state.shutdown_write();
}
Poll::Ready(Ok(0))
Expand Down
37 changes: 27 additions & 10 deletions src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use crate::common::tls_state::TlsState;
use crate::client;

use futures_io::{AsyncRead, AsyncWrite};
use rustls::{ClientConfig, ClientSession};
use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
use std::convert::TryFrom;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use webpki::DNSNameRef;

/// The TLS connecting part. The acceptor drives
/// the client side of the TLS handshake process. It works
Expand Down Expand Up @@ -64,10 +64,18 @@ impl From<ClientConfig> for TlsConnector {

impl Default for TlsConnector {
fn default() -> Self {
let mut config = ClientConfig::new();
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let mut root_certs = RootCertStore::empty();
root_certs.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_certs)
.with_no_client_auth();
Arc::new(config).into()
}
}
Expand Down Expand Up @@ -102,14 +110,14 @@ impl TlsConnector {
self.connect_with(domain, stream, |_| ())
}

// NOTE: Currently private, exposing ClientSession exposes rusttls
// NOTE: Currently private, exposing ClientConnection exposes rusttls
// Early data should be exposed differently
fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ClientSession),
F: FnOnce(&mut ClientConnection),
{
let domain = match DNSNameRef::try_from_ascii_str(domain.as_ref()) {
let domain = match ServerName::try_from(domain.as_ref()) {
Ok(domain) => domain,
Err(_) => {
return Connect(ConnectInner::Error(Some(io::Error::new(
Expand All @@ -119,7 +127,16 @@ impl TlsConnector {
}
};

let mut session = ClientSession::new(&self.inner, domain);
let mut session = match ClientConnection::new(self.inner.clone(), domain) {
Ok(session) => session,
Err(_) => {
return Connect(ConnectInner::Error(Some(io::Error::new(
io::ErrorKind::Other,
"invalid connection",
))))
}
};

f(&mut session);

#[cfg(not(feature = "early-data"))]
Expand Down
Loading

0 comments on commit 55f643c

Please sign in to comment.