From e535472b76fb5685393bbc97d48ea5b6506b74b4 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 22 Jun 2021 14:07:53 +0200 Subject: [PATCH] Upgrade to rustls 0.20 --- quinn-proto/Cargo.toml | 9 +- quinn-proto/src/config.rs | 50 +------ quinn-proto/src/crypto.rs | 5 +- quinn-proto/src/crypto/ring.rs | 45 ------ quinn-proto/src/crypto/rustls.rs | 245 ++++++++++++++++++------------- quinn-proto/src/crypto/types.rs | 41 ++++-- quinn-proto/src/packet.rs | 4 +- quinn-proto/src/tests/mod.rs | 61 ++++---- quinn-proto/src/tests/util.rs | 18 ++- 9 files changed, 223 insertions(+), 255 deletions(-) diff --git a/quinn-proto/Cargo.toml b/quinn-proto/Cargo.toml index ea8c556be4..403eb89100 100644 --- a/quinn-proto/Cargo.toml +++ b/quinn-proto/Cargo.toml @@ -20,7 +20,7 @@ maintenance = { status = "experimental" } default = ["tls-rustls"] # Use Google's list of CT logs to enable certificate transparency checks certificate-transparency = ["ct-logs"] -tls-rustls = ["rustls", "webpki", "ring"] +tls-rustls = ["rustls", "webpki", "ring", "rustls-pemfile"] # Trust the contents of the OS certificate store by default native-certs = ["rustls-native-certs"] @@ -28,19 +28,20 @@ native-certs = ["rustls-native-certs"] arbitrary = { version = "0.4.5", features = ["derive"], optional = true } bytes = "1" fxhash = "0.2.1" -ct-logs = { version = "0.8", optional = true } +ct-logs = { version = "0.9", optional = true } rand = "0.8" ring = { version = "0.16.7", optional = true } # If rustls gets updated to a new version which contains # https://github.com/ctz/rustls/commit/7117a805e0104705da50259357d8effa7d599e37 # the custom cipher list in `quinn-proto/src/crypto/rustls.rs` can be removed. -rustls = { version = "0.19", features = ["quic"], optional = true } +rustls = { version = "0.20.0-beta1", git = "https://github.com/djc/rustls", rev = "8f77bf5c878a2264b50756400982fbc141eb9940", features = ["quic"], optional = true } rustls-native-certs = { git = "https://github.com/djc/rustls-native-certs", rev = "c862ff371d8766deab109990f3e7b2e89d9af168", optional = true } +rustls-pemfile = { version = "0.2.1", optional = true } slab = "0.4" thiserror = "1.0.21" tinyvec = { version = "1.1", features = ["alloc"] } tracing = "0.1.10" -webpki = { version = "0.21", optional = true } +webpki = { version = "0.22", optional = true } [dev-dependencies] assert_matches = "1.1" diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index b8ae5102b8..158ea606cc 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -3,12 +3,10 @@ use std::{convert::TryInto, fmt, num::TryFromIntError, sync::Arc, time::Duration use rand::RngCore; use thiserror::Error; -#[cfg(feature = "rustls")] -use crate::crypto::types::{Certificate, CertificateChain, PrivateKey}; use crate::{ cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator}, congestion, - crypto::{self, ClientConfig as _, HandshakeTokenKey as _, HmacKey as _, ServerConfig as _}, + crypto::{self, ClientConfig as _, HandshakeTokenKey as _, HmacKey as _}, VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, }; @@ -457,10 +455,10 @@ where S: crypto::Session, { /// Create a default config with a particular `master_key` - pub fn new(prk: S::HandshakeTokenKey) -> Self { + pub fn new(prk: S::HandshakeTokenKey, crypto: S::ServerConfig) -> Self { Self { transport: Arc::new(TransportConfig::default()), - crypto: S::ServerConfig::new(), + crypto, token_key: Arc::new(prk), use_stateless_retry: false, @@ -511,19 +509,6 @@ where } } -#[cfg(feature = "rustls")] -impl ServerConfig { - /// Set the certificate chain that will be presented to clients - pub fn certificate( - &mut self, - cert_chain: CertificateChain, - key: PrivateKey, - ) -> Result<&mut Self, rustls::TLSError> { - Arc::make_mut(&mut self.crypto).set_single_cert(cert_chain.certs, key.inner)?; - Ok(self) - } -} - impl fmt::Debug for ServerConfig where S: crypto::Session, @@ -541,20 +526,6 @@ where } } -impl Default for ServerConfig -where - S: crypto::Session, -{ - fn default() -> Self { - let rng = &mut rand::thread_rng(); - - let mut master_key = [0u8; 64]; - rng.fill_bytes(&mut master_key); - - Self::new(S::HandshakeTokenKey::from_secret(&master_key)) - } -} - impl Clone for ServerConfig where S: crypto::Session, @@ -587,21 +558,6 @@ where pub crypto: S::ClientConfig, } -#[cfg(feature = "rustls")] -impl ClientConfig { - /// Add a trusted certificate authority - pub fn add_certificate_authority( - &mut self, - cert: Certificate, - ) -> Result<&mut Self, webpki::Error> { - let anchor = webpki::trust_anchor_util::cert_der_as_trust_anchor(&cert.inner.0)?; - Arc::make_mut(&mut self.crypto) - .root_store - .add_server_trust_anchors(&webpki::TLSServerTrustAnchors(&[anchor])); - Ok(self) - } -} - impl Default for ClientConfig where S: crypto::Session, diff --git a/quinn-proto/src/crypto.rs b/quinn-proto/src/crypto.rs index 1a84fd71b1..f05e82824e 100644 --- a/quinn-proto/src/crypto.rs +++ b/quinn-proto/src/crypto.rs @@ -16,6 +16,7 @@ use crate::{ config::ConfigError, shared::ConnectionId, transport_parameters::TransportParameters, ConnectError, Side, TransportError, }; +use crate::{Certificate, PrivateKey}; /// Cryptography interface based on *ring* #[cfg(feature = "ring")] @@ -158,8 +159,8 @@ pub trait ServerConfig: Clone + Send + Sync where S: Session, { - /// Construct the default configuration - fn new() -> Self + /// Construct a default configuration with a single server certificate + fn with_single_cert(cert_chain: Vec, key_der: PrivateKey) -> Self where Self: Sized; diff --git a/quinn-proto/src/crypto/ring.rs b/quinn-proto/src/crypto/ring.rs index ff4697d657..6d830aab9b 100644 --- a/quinn-proto/src/crypto/ring.rs +++ b/quinn-proto/src/crypto/ring.rs @@ -3,53 +3,8 @@ use ring::{aead, hkdf, hmac}; use crate::{ config::ConfigError, crypto::{self, CryptoError}, - packet::{PacketNumber, LONG_HEADER_FORM}, }; -impl crypto::HeaderKey for aead::quic::HeaderProtectionKey { - fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) { - let (header, sample) = packet.split_at_mut(pn_offset + 4); - let mask = self.new_mask(&sample[0..self.sample_size()]).unwrap(); - if header[0] & LONG_HEADER_FORM == LONG_HEADER_FORM { - // Long header: 4 bits masked - header[0] ^= mask[0] & 0x0f; - } else { - // Short header: 5 bits masked - header[0] ^= mask[0] & 0x1f; - } - let pn_length = PacketNumber::decode_len(header[0]); - for (out, inp) in header[pn_offset..pn_offset + pn_length] - .iter_mut() - .zip(&mask[1..]) - { - *out ^= inp; - } - } - - fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) { - let (header, sample) = packet.split_at_mut(pn_offset + 4); - let mask = self.new_mask(&sample[0..self.sample_size()]).unwrap(); - let pn_length = PacketNumber::decode_len(header[0]); - if header[0] & LONG_HEADER_FORM == LONG_HEADER_FORM { - // Long header: 4 bits masked - header[0] ^= mask[0] & 0x0f; - } else { - // Short header: 5 bits masked - header[0] ^= mask[0] & 0x1f; - } - for (out, inp) in header[pn_offset..pn_offset + pn_length] - .iter_mut() - .zip(&mask[1..]) - { - *out ^= inp; - } - } - - fn sample_size(&self) -> usize { - self.algorithm().sample_len() - } -} - impl crypto::HmacKey for hmac::Key { const KEY_LEN: usize = 64; type Signature = hmac::Tag; diff --git a/quinn-proto/src/crypto/rustls.rs b/quinn-proto/src/crypto/rustls.rs index 289ff9f2bd..dd3a8eba4c 100644 --- a/quinn-proto/src/crypto/rustls.rs +++ b/quinn-proto/src/crypto/rustls.rs @@ -1,4 +1,5 @@ use std::{ + convert::TryInto, io, ops::{Deref, DerefMut}, str, @@ -6,17 +7,19 @@ use std::{ }; use bytes::BytesMut; -use ring::{aead, aead::quic::HeaderProtectionKey, hkdf, hmac}; -pub use rustls::TLSError; +use ring::{aead, hkdf, hmac}; +pub use rustls::Error; use rustls::{ self, - quic::{ClientQuicExt, PacketKey, ServerQuicExt}, - Session, + quic::{ + ClientQuicExt, HeaderProtectionKey, KeyChange, PacketKey, Secrets, ServerQuicExt, Version, + }, + Connection, RootCertStore, }; -use webpki::DNSNameRef; use crate::{ crypto::{self, CryptoError, ExportKeyingMaterialError, KeyPair, Keys}, + packet::PacketNumber, transport_parameters::TransportParameters, CertificateChain, ConnectError, ConnectionId, Side, TransportError, TransportErrorCode, }; @@ -26,13 +29,14 @@ use crate::{ pub struct TlsSession { using_alpn: bool, got_handshake_data: bool, + next_secrets: Option, inner: SessionKind, } #[derive(Debug)] enum SessionKind { - Client(rustls::ClientSession), - Server(rustls::ServerSession), + Client(rustls::ClientConnection), + Server(rustls::ServerConnection), } impl TlsSession { @@ -55,13 +59,7 @@ impl crypto::Session for TlsSession { type ServerConfig = Arc; fn initial_keys(dst_cid: &ConnectionId, side: Side) -> Keys { - const INITIAL_SALT: [u8; 20] = [ - 0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, - 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99, - ]; - - let salt = ring::hkdf::Salt::new(ring::hkdf::HKDF_SHA256, &INITIAL_SALT); - let keys = rustls::quic::Keys::initial(&salt, dst_cid, side.is_client()); + let keys = rustls::quic::Keys::initial(Version::V1Draft, dst_cid, side.is_client()); Keys { header: KeyPair { local: keys.local.header, @@ -79,20 +77,20 @@ impl crypto::Session for TlsSession { return None; } Some(HandshakeData { - protocol: self.get_alpn_protocol().map(|x| x.into()), + protocol: self.alpn_protocol().map(|x| x.into()), server_name: match self.inner { SessionKind::Client(_) => None, - SessionKind::Server(ref session) => session.get_sni_hostname().map(|x| x.into()), + SessionKind::Server(ref session) => session.sni_hostname().map(|x| x.into()), }, }) } fn peer_identity(&self) -> Option { - self.get_peer_certificates().map(|v| v.into()) + self.peer_certificates().map(|v| v.into()) } fn early_crypto(&self) -> Option<(Self::HeaderKey, Self::PacketKey)> { - let keys = self.get_0rtt_keys()?; + let keys = self.zero_rtt_keys()?; Some((keys.header, keys.packet)) } @@ -112,7 +110,7 @@ impl crypto::Session for TlsSession { fn read_handshake(&mut self, buf: &[u8]) -> Result { self.read_hs(buf).map_err(|e| { - if let Some(alert) = self.get_alert() { + if let Some(alert) = self.alert() { TransportError { code: TransportErrorCode::crypto(alert.get_u8()), frame: None, @@ -128,11 +126,11 @@ impl crypto::Session for TlsSession { // connections. let have_server_name = match self.inner { SessionKind::Client(_) => false, - SessionKind::Server(ref session) => session.get_sni_hostname().is_some(), + SessionKind::Server(ref session) => session.sni_hostname().is_some(), }; - if self.get_alpn_protocol().is_some() || have_server_name || !self.is_handshaking() { + if self.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() { self.got_handshake_data = true; - if self.using_alpn && self.get_alpn_protocol().is_none() { + if self.using_alpn && self.alpn_protocol().is_none() { // rustls ignores total ALPN failure for compat, but QUIC gets a fresh start return Err(TransportError { code: TransportErrorCode::crypto(0x78), @@ -147,7 +145,7 @@ impl crypto::Session for TlsSession { } fn transport_parameters(&self) -> Result, TransportError> { - match self.get_quic_transport_parameters() { + match self.quic_transport_parameters() { None => Ok(None), Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) { Ok(params) => Ok(Some(params)), @@ -157,7 +155,14 @@ impl crypto::Session for TlsSession { } fn write_handshake(&mut self, buf: &mut Vec) -> Option> { - let keys = self.write_hs(buf)?; + let keys = match self.write_hs(buf)? { + KeyChange::Handshake { keys } => keys, + KeyChange::OneRtt { keys, next } => { + self.next_secrets = Some(next); + keys + } + }; + Some(Keys { header: KeyPair { local: keys.local.header, @@ -171,7 +176,8 @@ impl crypto::Session for TlsSession { } fn next_1rtt_keys(&mut self) -> Option> { - let keys = (**self).next_1rtt_keys(); + let secrets = self.next_secrets.as_mut()?; + let keys = secrets.next_packet_keys(); Some(KeyPair { local: keys.local, remote: keys.remote, @@ -226,7 +232,7 @@ impl crypto::Session for TlsSession { label: &[u8], context: &[u8], ) -> Result<(), ExportKeyingMaterialError> { - let session: &dyn rustls::Session = match &self.inner { + let session: &dyn rustls::Connection = match &self.inner { SessionKind::Client(s) => s, SessionKind::Server(s) => s, }; @@ -237,7 +243,7 @@ impl crypto::Session for TlsSession { } impl Deref for TlsSession { - type Target = dyn rustls::Session; + type Target = dyn rustls::Connection; fn deref(&self) -> &Self::Target { match self.inner { SessionKind::Client(ref session) => session, @@ -247,7 +253,7 @@ impl Deref for TlsSession { } impl DerefMut for TlsSession { - fn deref_mut(&mut self) -> &mut (dyn rustls::Session + 'static) { + fn deref_mut(&mut self) -> &mut (dyn rustls::Connection + 'static) { match self.inner { SessionKind::Client(ref mut session) => session, SessionKind::Server(ref mut session) => session, @@ -262,6 +268,36 @@ const RETRY_INTEGRITY_NONCE: [u8; 12] = [ 0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c, ]; +impl crypto::HeaderKey for HeaderProtectionKey { + fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) { + let (header, sample) = packet.split_at_mut(pn_offset + 4); + let (first, rest) = header.split_at_mut(1); + let pn_length = PacketNumber::decode_len(first[0]); + self.xor_in_place( + sample, + &mut first[0], + &mut rest[pn_offset - 1..pn_offset + pn_length - 1], + ) + .unwrap(); + } + + fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) { + let (header, sample) = packet.split_at_mut(pn_offset + 4); + let (first, rest) = header.split_at_mut(1); + let pn_length = PacketNumber::decode_len(first[0]); + self.xor_in_place( + sample, + &mut first[0], + &mut rest[pn_offset - 1..pn_offset + pn_length - 1], + ) + .unwrap(); + } + + fn sample_size(&self) -> usize { + self.sample_len() + } +} + /// Authentication data for (rustls) TLS session pub struct HandshakeData { /// The negotiated application protocol, if ALPN is in use @@ -276,28 +312,44 @@ pub struct HandshakeData { impl crypto::ClientConfig for Arc { fn new() -> Self { - let mut cfg = rustls::ClientConfig::with_ciphersuites(&QUIC_CIPHER_SUITES); - cfg.versions = vec![rustls::ProtocolVersion::TLSv1_3]; - cfg.enable_early_data = true; + let cfg = rustls::config_builder() + .with_cipher_suites(&QUIC_CIPHER_SUITES) + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .for_client() + .unwrap(); + + #[allow(unused_mut)] + let mut roots = RootCertStore::empty(); #[cfg(feature = "native-certs")] - match rustls_native_certs::load_native_certs() { - Ok(certs) => { - let mut roots = rustls::RootCertStore::empty(); - for cert in certs { - if let Err(e) = roots.add(&rustls::Certificate(cert.0)) { - tracing::warn!("failed to parse trust anchor: {}", e); - } + { + let certs = match rustls_native_certs::load_native_certs() { + Ok(certs) => certs, + Err(e) => { + tracing::warn!("couldn't load any default trust roots: {}", e); + Vec::new() + } + }; + + for cert in certs { + if let Err(e) = roots.add(&rustls::Certificate(cert.0)) { + tracing::warn!("failed to parse trust anchor: {}", e); } - cfg.root_store = roots; - } - Err(e) => { - tracing::warn!("couldn't load any default trust roots: {}", e); } } + + #[allow(unused_mut)] + let mut ct_logs = &[]; #[cfg(feature = "certificate-transparency")] { - cfg.ct_logs = Some(&ct_logs::LOGS); + ct_logs = ct_logs::LOGS; } + + let mut cfg = cfg + .with_root_certificates(roots, ct_logs) + .with_no_client_auth(); + + cfg.enable_early_data = true; Arc::new(cfg) } @@ -306,36 +358,55 @@ impl crypto::ClientConfig for Arc { server_name: &str, params: &TransportParameters, ) -> Result { - let pki_server_name = DNSNameRef::try_from_ascii_str(server_name) - .map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?; Ok(TlsSession { using_alpn: !self.alpn_protocols.is_empty(), got_handshake_data: false, - inner: SessionKind::Client(rustls::ClientSession::new_quic( - self, - pki_server_name, - to_vec(params), - )), + next_secrets: None, + inner: SessionKind::Client( + rustls::ClientConnection::new_quic( + self.clone(), + Version::V1Draft, + server_name + .try_into() + .map_err(|_| ConnectError::InvalidDnsName(server_name.into()))?, + to_vec(params), + ) + .unwrap(), + ), }) } } impl crypto::ServerConfig for Arc { - fn new() -> Self { - let mut cfg = rustls::ServerConfig::with_ciphersuites( - rustls::NoClientAuth::new(), - &QUIC_CIPHER_SUITES, - ); - cfg.versions = vec![rustls::ProtocolVersion::TLSv1_3]; - cfg.max_early_data_size = u32::max_value(); - Arc::new(cfg) + fn with_single_cert(cert_chain: Vec, key_der: crate::PrivateKey) -> Self + where + Self: Sized, + { + Arc::new( + rustls::config_builder() + .with_cipher_suites(&QUIC_CIPHER_SUITES) + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .for_server() + .unwrap() + .with_no_client_auth() + .with_single_cert( + cert_chain.into_iter().map(|c| c.inner).collect(), + key_der.inner, + ) + .unwrap(), + ) } fn start_session(&self, params: &TransportParameters) -> TlsSession { TlsSession { using_alpn: !self.alpn_protocols.is_empty(), got_handshake_data: false, - inner: SessionKind::Server(rustls::ServerSession::new_quic(self, to_vec(params))), + next_secrets: None, + inner: SessionKind::Server( + rustls::ServerConnection::new_quic(self.clone(), Version::V1Draft, to_vec(params)) + .unwrap(), + ), } } } @@ -348,16 +419,10 @@ fn to_vec(params: &TransportParameters) -> Vec { impl crypto::PacketKey for PacketKey { fn encrypt(&self, packet: u64, buf: &mut [u8], header_len: usize) { - let (header, payload) = buf.split_at_mut(header_len); - let (payload, tag_storage) = - payload.split_at_mut(payload.len() - self.key.algorithm().tag_len()); - let aad = aead::Aad::from(header); - let nonce = self.iv.nonce_for(packet); - let tag = self - .key - .seal_in_place_separate_tag(nonce, aad, payload) - .unwrap(); - tag_storage.copy_from_slice(tag.as_ref()); + let (header, payload_tag) = buf.split_at_mut(header_len); + let (payload, tag_buf) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len()); + let tag = self.encrypt_in_place(packet, &*header, payload).unwrap(); + tag_buf.copy_from_slice(tag.as_ref()); } fn decrypt( @@ -366,42 +431,24 @@ impl crypto::PacketKey for PacketKey { header: &[u8], payload: &mut BytesMut, ) -> Result<(), CryptoError> { - if payload.len() < self.key.algorithm().tag_len() { - return Err(CryptoError); - } - - let payload_len = payload.len(); - let aad = aead::Aad::from(header); - let nonce = self.iv.nonce_for(packet); - self.key.open_in_place(nonce, aad, payload.as_mut())?; - payload.truncate(payload_len - self.key.algorithm().tag_len()); + let plain = self + .decrypt_in_place(packet, &*header, payload.as_mut()) + .map_err(|_| CryptoError)?; + let plain_len = plain.len(); + payload.truncate(plain_len); Ok(()) } fn tag_len(&self) -> usize { - self.key.algorithm().tag_len() + self.tag_len() } fn confidentiality_limit(&self) -> u64 { - let cipher = self.key.algorithm(); - if cipher == &aead::AES_128_GCM || cipher == &aead::AES_256_GCM { - 2u64.pow(23) - } else if cipher == &aead::CHACHA20_POLY1305 { - u64::MAX - } else { - panic!("unknown cipher") - } + self.confidentiality_limit() } fn integrity_limit(&self) -> u64 { - let cipher = self.key.algorithm(); - if cipher == &aead::AES_128_GCM || cipher == &aead::AES_256_GCM { - 2u64.pow(52) - } else if cipher == &aead::CHACHA20_POLY1305 { - 2u64.pow(36) - } else { - panic!("unknown cipher") - } + self.integrity_limit() } } @@ -413,8 +460,8 @@ impl crypto::PacketKey for PacketKey { /// This list prefers AES ciphers, which are hardware accelerated on most platforms. /// This list can be removed if the rustls dependency is updated to a new version /// which contains the linked change. -static QUIC_CIPHER_SUITES: [&rustls::SupportedCipherSuite; 3] = [ - &rustls::ciphersuite::TLS13_AES_256_GCM_SHA384, - &rustls::ciphersuite::TLS13_AES_128_GCM_SHA256, - &rustls::ciphersuite::TLS13_CHACHA20_POLY1305_SHA256, +static QUIC_CIPHER_SUITES: [rustls::SupportedCipherSuite; 3] = [ + rustls::cipher_suite::TLS13_AES_256_GCM_SHA384, + rustls::cipher_suite::TLS13_AES_128_GCM_SHA256, + rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, ]; diff --git a/quinn-proto/src/crypto/types.rs b/quinn-proto/src/crypto/types.rs index acf8971e03..7fe32dd8e8 100644 --- a/quinn-proto/src/crypto/types.rs +++ b/quinn-proto/src/crypto/types.rs @@ -1,7 +1,5 @@ use std::fmt; -use rustls::internal::pemfile; - /// A single TLS certificate #[derive(Debug, Clone)] pub struct Certificate { @@ -18,9 +16,12 @@ impl Certificate { /// Parse a PEM-formatted certificate pub fn from_pem(pem: &[u8]) -> Result { - let certs = pemfile::certs(&mut &*pem).map_err(|()| ParseError("invalid pem cert"))?; + let certs = + rustls_pemfile::certs(&mut &*pem).map_err(|_| ParseError("invalid pem cert"))?; if let Some(pem) = certs.into_iter().next() { - return Ok(Self { inner: pem }); + return Ok(Self { + inner: rustls::Certificate(pem), + }); } Err(ParseError("no cert found")) @@ -52,9 +53,11 @@ impl CertificateChain { /// let cert_chain = quinn_proto::PrivateKey::from_pem(&pem).expect("error parsing certificates"); /// ``` pub fn from_pem(pem: &[u8]) -> Result { + let der_certs = rustls_pemfile::certs(&mut &*pem) + .map_err(|_| ParseError("malformed certificate chain"))?; + Ok(Self { - certs: pemfile::certs(&mut &*pem) - .map_err(|()| ParseError("malformed certificate chain"))?, + certs: der_certs.into_iter().map(rustls::Certificate).collect(), }) } @@ -80,9 +83,11 @@ impl std::iter::FromIterator for CertificateChain { } } -impl From> for CertificateChain { - fn from(certs: Vec) -> Self { - Self { certs } +impl From<&[rustls::Certificate]> for CertificateChain { + fn from(certs: &[rustls::Certificate]) -> Self { + Self { + certs: certs.to_vec(), + } } } @@ -118,16 +123,22 @@ impl PrivateKey { /// let key = quinn_proto::PrivateKey::from_pem(&pem).expect("error parsing key"); /// ``` pub fn from_pem(pem: &[u8]) -> Result { - let pkcs8 = pemfile::pkcs8_private_keys(&mut &*pem) - .map_err(|()| ParseError("malformed PKCS #8 private key"))?; + let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*pem) + .map_err(|_| ParseError("malformed PKCS #8 private key"))?; if let Some(x) = pkcs8.into_iter().next() { - return Ok(Self { inner: x }); + return Ok(Self { + inner: rustls::PrivateKey(x), + }); } - let rsa = pemfile::rsa_private_keys(&mut &*pem) - .map_err(|()| ParseError("malformed PKCS #1 private key"))?; + + let rsa = rustls_pemfile::rsa_private_keys(&mut &*pem) + .map_err(|_| ParseError("malformed PKCS #1 private key"))?; if let Some(x) = rsa.into_iter().next() { - return Ok(Self { inner: x }); + return Ok(Self { + inner: rustls::PrivateKey(x), + }); } + Err(ParseError("no private key found")) } diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index c2b5bee183..ca601a0cc3 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -832,7 +832,7 @@ mod tests { #[test] fn header_encoding() { use crate::{ - crypto::{rustls::TlsSession, PacketKey, Session}, + crypto::{rustls::TlsSession, Session}, Side, }; @@ -880,7 +880,7 @@ mod tests { server .packet .remote - .decrypt(0, &packet.header_data, &mut packet.payload) + .decrypt_in_place(0, &packet.header_data, &mut packet.payload) .unwrap(); assert_eq!(packet.payload[..], [0; 16]); match packet.header { diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 52e63b1452..16f744742c 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -447,14 +447,13 @@ fn zero_rtt_happypath() { fn zero_rtt_rejection() { let _guard = subscribe(); let mut server_config = server_config(); - Arc::get_mut(&mut server_config.crypto) - .unwrap() - .set_protocols(&["foo".into(), "bar".into()]); + let server_config_mut = Arc::get_mut(&mut server_config.crypto).unwrap(); + server_config_mut.alpn_protocols = vec!["foo".into(), "bar".into()]; + let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_config = client_config(); - Arc::get_mut(&mut client_config.crypto) - .unwrap() - .set_protocols(&["foo".into()]); + let client_config_mut = Arc::get_mut(&mut client_config.crypto).unwrap(); + client_config_mut.alpn_protocols = vec!["foo".into()]; // Establish normal connection let client_ch = pair.begin_connect(client_config.clone()); @@ -484,9 +483,9 @@ fn zero_rtt_rejection() { pair.server.connections.clear(); // Changing protocols invalidates 0-RTT - Arc::get_mut(&mut client_config.crypto) - .unwrap() - .set_protocols(&["bar".into()]); + let mut client_config_mut = Arc::get_mut(&mut client_config.crypto).unwrap(); + client_config_mut.alpn_protocols = vec!["bar".into()]; + info!("resuming session"); let client_ch = pair.begin_connect(client_config); assert!(pair.client_conn_mut(client_ch).has_0rtt()); @@ -519,14 +518,13 @@ fn zero_rtt_rejection() { fn alpn_success() { let _guard = subscribe(); let mut server_config = server_config(); - Arc::get_mut(&mut server_config.crypto) - .unwrap() - .set_protocols(&["foo".into(), "bar".into(), "baz".into()]); + let mut server_config_mut = Arc::get_mut(&mut server_config.crypto).unwrap(); + server_config_mut.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()]; + let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let mut client_config = client_config(); - Arc::get_mut(&mut client_config.crypto) - .unwrap() - .set_protocols(&["bar".into(), "quux".into(), "corge".into()]); + let mut client_config_mut = Arc::get_mut(&mut client_config.crypto).unwrap(); + client_config_mut.alpn_protocols = vec!["bar".into(), "quux".into(), "corge".into()]; // Establish normal connection let client_ch = pair.begin_connect(client_config); @@ -554,9 +552,8 @@ fn server_alpn_unset() { let _guard = subscribe(); let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config()); let mut client_config = client_config(); - Arc::get_mut(&mut client_config.crypto) - .unwrap() - .set_protocols(&["foo".into()]); + let mut client_config_mut = Arc::get_mut(&mut client_config.crypto).unwrap(); + client_config_mut.alpn_protocols = vec!["foo".into()]; let client_ch = pair.begin_connect(client_config); pair.drive(); @@ -570,9 +567,8 @@ fn server_alpn_unset() { fn client_alpn_unset() { let _guard = subscribe(); let mut server_config = server_config(); - Arc::get_mut(&mut server_config.crypto) - .unwrap() - .set_protocols(&["foo".into(), "bar".into(), "baz".into()]); + let mut server_config_mut = Arc::get_mut(&mut server_config.crypto).unwrap(); + server_config_mut.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()]; let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); let client_ch = pair.begin_connect(client_config()); @@ -586,14 +582,13 @@ fn client_alpn_unset() { #[test] fn alpn_mismatch() { let mut server_config = server_config(); - Arc::get_mut(&mut server_config.crypto) - .unwrap() - .set_protocols(&["foo".into(), "bar".into(), "baz".into()]); + let mut server_config_mut = Arc::get_mut(&mut server_config.crypto).unwrap(); + server_config_mut.alpn_protocols = vec!["foo".into(), "bar".into(), "baz".into()]; let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); + let mut client_config = client_config(); - Arc::get_mut(&mut client_config.crypto) - .unwrap() - .set_protocols(&["quux".into(), "corge".into()]); + let mut client_config_mut = Arc::get_mut(&mut client_config.crypto).unwrap(); + client_config_mut.alpn_protocols = vec!["quux".into(), "corge".into()]; let client_ch = pair.begin_connect(client_config); pair.drive(); @@ -1549,17 +1544,17 @@ fn datagram_unsupported() { fn large_initial() { let _guard = subscribe(); let mut server_config = server_config(); - Arc::get_mut(&mut server_config.crypto) - .unwrap() - .set_protocols(&[vec![0, 0, 0, 42]]); + let mut server_config_mut = Arc::get_mut(&mut server_config.crypto).unwrap(); + server_config_mut.alpn_protocols = vec![vec![0, 0, 0, 42]]; let mut pair = Pair::new(Arc::new(EndpointConfig::default()), server_config); + let mut cfg = client_config(); let protocols = (0..1000u32) .map(|x| x.to_be_bytes().to_vec()) .collect::>(); - Arc::get_mut(&mut cfg.crypto) - .unwrap() - .set_protocols(&protocols); + let mut cfg_mut = Arc::get_mut(&mut cfg.crypto).unwrap(); + cfg_mut.alpn_protocols = protocols; + let client_ch = pair.begin_connect(cfg); pair.drive(); let server_ch = pair.server.assert_accept(); diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 09f1038d52..52cfe8d96a 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -379,13 +379,14 @@ pub fn server_config() -> ServerConfig { let key = CERTIFICATE.serialize_private_key_der(); let cert = CERTIFICATE.serialize_pem().unwrap(); - let mut crypto = crypto::ServerConfig::new(); - Arc::make_mut(&mut crypto) - .set_single_cert( - rustls::internal::pemfile::certs(&mut cert.as_bytes()).unwrap(), - rustls::PrivateKey(key.to_vec()), - ) - .unwrap(); + let mut crypto = crypto::ServerConfig::with_single_cert( + rustls_pemfile::certs(&mut cert.as_bytes()) + .unwrap() + .into_iter() + .map(|cert| crate::Certificate { inner: rustls::Certificate(cert) }) + .collect(), + crate::PrivateKey { inner: rustls::PrivateKey(key.to_vec()) }, + ); ServerConfig { crypto, ..Default::default() @@ -394,13 +395,14 @@ pub fn server_config() -> ServerConfig { pub fn client_config() -> ClientConfig { let cert = CERTIFICATE.serialize_der().unwrap(); - let anchor = webpki::trust_anchor_util::cert_der_as_trust_anchor(&cert).unwrap(); + let anchor = webpki::TrustAnchor::try_from_cert_der(&cert).unwrap(); let anchor_vec = vec![anchor]; let mut crypto = crypto::ClientConfig::new(); Arc::make_mut(&mut crypto) .root_store .add_server_trust_anchors(&webpki::TLSServerTrustAnchors(&anchor_vec)); + Arc::make_mut(&mut crypto).key_log = Arc::new(KeyLogFile::new()); Arc::make_mut(&mut crypto).enable_early_data = true; ClientConfig {