Skip to content

Commit

Permalink
Use default 'CryptoProvider' for all TLS ops.
Browse files Browse the repository at this point in the history
Prior to this commit, some TLS related operations used 'ring' even when
a different default 'CryptoProvider' was installed. This commit fixes
that by refactoring 'TlsConfig' such that all utility methods are
required to use the default 'CryptoProvider'.

This commit also cleans up code related to the rustls 0.23 update.
  • Loading branch information
SergioBenitez committed Mar 31, 2024
1 parent a96d221 commit 7a039c2
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 184 deletions.
21 changes: 16 additions & 5 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ serde_json = { version = "1.0.26", optional = true }
rmp-serde = { version = "1", optional = true }
uuid_ = { package = "uuid", version = "1", optional = true, features = ["serde"] }

# Optional TLS dependencies
rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"], optional = true }
tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"], optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }

# Optional MTLS dependencies
x509-parser = { version = "0.16", optional = true }

Expand Down Expand Up @@ -111,6 +106,22 @@ version = "0.6.0-dev"
path = "../http"
features = ["serde"]

[dependencies.rustls]
version = "0.23"
default-features = false
features = ["ring", "logging", "std", "tls12"]
optional = true

[dependencies.tokio-rustls]
version = "0.26"
default-features = false
features = ["logging", "tls12", "ring"]
optional = true

[dependencies.rustls-pemfile]
version = "2.1.0"
optional = true

[dependencies.s2n-quic]
version = "1.32"
default-features = false
Expand Down
8 changes: 4 additions & 4 deletions core/lib/src/listener/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ impl QuicListener {
use quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES, Server as H3TlsServer};

// FIXME: Remove this as soon as `s2n_quic` is on rustls >= 0.22.
let cert_chain = crate::tls::util::load_cert_chain(&mut tls.certs_reader().unwrap())
.unwrap()
let cert_chain = tls.load_certs()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.into_iter()
.map(|v| v.to_vec())
.map(rustls::Certificate)
.collect::<Vec<_>>();

let key = crate::tls::util::load_key(&mut tls.key_reader().unwrap())
.unwrap()
let key = tls.load_key()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.secret_der()
.to_vec();

Expand Down
16 changes: 5 additions & 11 deletions core/lib/src/listener/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;

use crate::tls::{TlsConfig, Error};
use crate::tls::util::{self, load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};

#[doc(inline)]
Expand All @@ -29,16 +28,13 @@ pub struct TlsBindable<I> {

impl TlsConfig {
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..util::get_crypto_provider()
};
let provider = Arc::new(self.default_crypto_provider());

#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
let ca = Arc::new(mtls.load_ca_certs()?);
let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone());
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
Expand All @@ -50,12 +46,10 @@ impl TlsConfig {
#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();

let key = load_key(&mut self.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
let mut tls_config = ServerConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;
.with_single_cert(self.load_certs()?, self.load_key()?)?;

tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
Expand Down
8 changes: 5 additions & 3 deletions core/lib/src/local/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,14 @@ macro_rules! pub_request_impl {
#[cfg_attr(nightly, doc(cfg(feature = "mtls")))]
pub fn identity<C: std::io::Read>(mut self, reader: C) -> Self {
use std::sync::Arc;
use crate::tls::util::load_cert_chain;
use crate::listener::Certificates;

let mut reader = std::io::BufReader::new(reader);
let certs = load_cert_chain(&mut reader).map(Certificates::from);
self._request_mut().connection.peer_certs = certs.ok().map(Arc::new);
self._request_mut().connection.peer_certs = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map(|certs| Arc::new(Certificates::from(certs)))
.ok();

self
}

Expand Down
13 changes: 13 additions & 0 deletions core/lib/src/mtls/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::io;
use figment::value::magic::{RelativePathBuf, Either};
use serde::{Serialize, Deserialize};

use crate::tls::{Result, Error};

/// Mutual TLS configuration.
///
/// Configuration works in concert with the [`mtls`](crate::mtls) module, which
Expand Down Expand Up @@ -142,6 +144,7 @@ impl MtlsConfig {
}

/// Returns the value of the `ca_certs` parameter.
///
/// # Example
///
/// ```rust
Expand All @@ -162,6 +165,16 @@ impl MtlsConfig {
pub fn ca_certs_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
crate::tls::config::to_reader(&self.ca_certs)
}

/// Load and decode CA certificates from `reader`.
pub(crate) fn load_ca_certs(&self) -> Result<rustls::RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_pemfile::certs(&mut self.ca_certs_reader()?) {
roots.add(cert?).map_err(Error::CertAuth)?;
}

Ok(roots)
}
}

#[cfg(test)]
Expand Down
136 changes: 109 additions & 27 deletions core/lib/src/tls/config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::io;

use rustls::crypto::{ring, CryptoProvider};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use figment::value::magic::{Either, RelativePathBuf};
use serde::{Deserialize, Serialize};
use indexmap::IndexSet;

use crate::tls::error::{Result, Error, KeyError};

/// TLS configuration: certificate chain, key, and ciphersuites.
///
/// Four parameters control `tls` configuration:
Expand Down Expand Up @@ -431,6 +435,72 @@ impl TlsConfig {
}
}

/// Loads certificates from `reader`.
impl TlsConfig {
pub(crate) fn load_certs(&self) -> Result<Vec<CertificateDer<'static>>> {
rustls_pemfile::certs(&mut self.certs_reader()?)
.collect::<Result<_, _>>()
.map_err(Error::CertChain)
}

/// Load and decode the private key from `reader`.
pub(crate) fn load_key(&self) -> Result<PrivateKeyDer<'static>> {
use rustls_pemfile::Item::*;

let mut keys = rustls_pemfile::read_all(&mut self.key_reader()?)
.map(|result| result.map_err(KeyError::Io)
.and_then(|item| match item {
Pkcs1Key(key) => Ok(key.into()),
Pkcs8Key(key) => Ok(key.into()),
Sec1Key(key) => Ok(key.into()),
_ => Err(KeyError::BadItem(item))
})
)
.collect::<Result<Vec<PrivateKeyDer<'static>>, _>>()?;

if keys.len() != 1 {
return Err(KeyError::BadKeyCount(keys.len()).into());
}

// Ensure we can use the key.
let key = keys.remove(0);
self.default_crypto_provider()
.key_provider
.load_private_key(key.clone_key())
.map_err(KeyError::Unsupported)?;

Ok(key)
}

pub(crate) fn default_crypto_provider(&self) -> CryptoProvider {
CryptoProvider::get_default()
.map(|arc| (**arc).clone())
.unwrap_or_else(|| rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|cipher| match cipher {
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS13_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS13_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 =>
ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}).collect(),
..ring::default_provider()
})
}
}

impl CipherSuite {
/// The default set and order of cipher suites. These are all of the
/// variants in [`CipherSuite`] in their declaration order.
Expand Down Expand Up @@ -474,33 +544,6 @@ impl CipherSuite {
}
}

impl From<CipherSuite> for rustls::SupportedCipherSuite {
fn from(cipher: CipherSuite) -> Self {
use rustls::crypto::ring::cipher_suite;

match cipher {
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_AES_256_GCM_SHA384 =>
cipher_suite::TLS13_AES_256_GCM_SHA384,
CipherSuite::TLS_AES_128_GCM_SHA256 =>
cipher_suite::TLS13_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 =>
cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 =>
cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 =>
cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 =>
cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}
}
}

pub(crate) fn to_reader(
value: &Either<RelativePathBuf, Vec<u8>>
) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
Expand All @@ -522,6 +565,7 @@ pub(crate) fn to_reader(

#[cfg(test)]
mod tests {
use super::*;
use figment::{Figment, providers::{Toml, Format}};

#[test]
Expand Down Expand Up @@ -650,4 +694,42 @@ mod tests {
Ok(())
});
}

macro_rules! tls_example_private_pem {
($k:expr) => {
concat!(env!("CARGO_MANIFEST_DIR"), "/../../examples/tls/private/", $k)
}
}

#[test]
fn verify_load_private_keys_of_different_types() -> Result<()> {
let key_paths = [
tls_example_private_pem!("rsa_sha256_key.pem"),
tls_example_private_pem!("ecdsa_nistp256_sha256_key_pkcs8.pem"),
tls_example_private_pem!("ecdsa_nistp384_sha384_key_pkcs8.pem"),
tls_example_private_pem!("ed25519_key.pem"),
];

for key in key_paths {
TlsConfig::from_paths("", key).load_key()?;
}

Ok(())
}

#[test]
fn verify_load_certs_of_different_types() -> Result<()> {
let cert_paths = [
tls_example_private_pem!("rsa_sha256_cert.pem"),
tls_example_private_pem!("ecdsa_nistp256_sha256_cert.pem"),
tls_example_private_pem!("ecdsa_nistp384_sha384_cert.pem"),
tls_example_private_pem!("ed25519_cert.pem"),
];

for cert in cert_paths {
TlsConfig::from_paths(cert, "").load_certs()?;
}

Ok(())
}
}
1 change: 0 additions & 1 deletion core/lib/src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mod error;
pub(crate) mod config;
pub(crate) mod util;

pub use error::Result;
pub use config::{TlsConfig, CipherSuite};
Expand Down
Loading

0 comments on commit 7a039c2

Please sign in to comment.