From fc18505ce16bc30a113392e5b29df8e9a982d519 Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Tue, 4 Apr 2023 11:14:57 +1000 Subject: [PATCH] Remove openssl --- Cargo.lock | 9 +- .../tests/cassandra_int_tests/cluster/mod.rs | 2 +- shotover/src/lib.rs | 1 - shotover/src/server.rs | 9 +- shotover/src/sources/cassandra.rs | 2 +- shotover/src/sources/kafka.rs | 2 +- shotover/src/sources/redis.rs | 2 +- shotover/src/tls.rs | 363 +++++++++++------- shotover/src/tlsls.rs | 319 --------------- .../src/transforms/cassandra/connection.rs | 2 +- .../transforms/cassandra/sink_cluster/mod.rs | 2 +- .../transforms/cassandra/sink_cluster/node.rs | 2 +- .../src/transforms/cassandra/sink_single.rs | 2 +- shotover/src/transforms/redis/sink_cluster.rs | 2 +- shotover/src/transforms/redis/sink_single.rs | 2 +- .../util/cluster_connection_pool.rs | 2 +- .../src/connection/redis_connection.rs | 2 - 17 files changed, 240 insertions(+), 485 deletions(-) delete mode 100644 shotover/src/tlsls.rs diff --git a/Cargo.lock b/Cargo.lock index ccdc4cfd3..14d96e096 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2192,9 +2192,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.48" +version = "0.10.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "518915b97df115dd36109bfa429a48b8f737bd05508cf9588977b599648926d2" +checksum = "4d2f106ab837a24e03672c59b1239669a0596406ff657c3c0835b6b7f0f35a33" dependencies = [ "bitflags 1.3.2", "cfg-if", @@ -2233,11 +2233,10 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.83" +version = "0.9.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666416d899cf077260dac8698d60a60b435a46d57e82acb1be3d0dad87284e5b" +checksum = "3a20eace9dc2d82904039cb76dcf50fb1a0bba071cfd1629720b5d6f1ddba0fa" dependencies = [ - "autocfg", "cc", "libc", "openssl-src", diff --git a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs index 803a41649..246d6bb3a 100644 --- a/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs +++ b/shotover-proxy/tests/cassandra_int_tests/cluster/mod.rs @@ -2,7 +2,7 @@ use cassandra_protocol::frame::message_startup::BodyReqStartup; use cassandra_protocol::frame::Version; use shotover::frame::{cassandra::Tracing, CassandraFrame, CassandraOperation, Frame}; use shotover::message::Message; -use shotover::tlsls::{TlsConnector, TlsConnectorConfig}; +use shotover::tls::{TlsConnector, TlsConnectorConfig}; use shotover::transforms::cassandra::sink_cluster::{ node::{CassandraNode, ConnectionFactory}, topology::{create_topology_task, TaskConnectionInfo}, diff --git a/shotover/src/lib.rs b/shotover/src/lib.rs index 888cdcca1..2b568f039 100644 --- a/shotover/src/lib.rs +++ b/shotover/src/lib.rs @@ -33,6 +33,5 @@ mod server; pub mod sources; pub mod tcp; pub mod tls; -pub mod tlsls; pub mod tracing_panic_handler; pub mod transforms; diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 2a7d26d1f..65969ddf5 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -1,12 +1,13 @@ use crate::codec::{CodecBuilder, CodecReadError}; use crate::message::Messages; -use crate::tlsls::{AcceptError, TlsAcceptor}; +use crate::tls::{AcceptError, TlsAcceptor}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; use crate::transforms::Wrapper; use anyhow::{anyhow, Context, Result}; use futures::future::join_all; use futures::{SinkExt, StreamExt}; use metrics::{register_gauge, Gauge}; +use std::io::ErrorKind; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; @@ -363,7 +364,11 @@ fn spawn_read_write_tasks< return; } Err(CodecReadError::Io(err)) => { - warn!("failed to receive message on tcp stream: {:?}", err); + // I suspect (but have not confirmed) that UnexpectedEof occurs here when the ssl client does not send "close notify" before terminating the connection. + // We shouldnt report that as a warning because its common for clients to do that for performance reasons. + if !matches!(err.kind(), ErrorKind::UnexpectedEof) { + warn!("failed to receive message on tcp stream: {:?}", err); + } return; } } diff --git a/shotover/src/sources/cassandra.rs b/shotover/src/sources/cassandra.rs index ad82296cc..074c548bc 100644 --- a/shotover/src/sources/cassandra.rs +++ b/shotover/src/sources/cassandra.rs @@ -2,7 +2,7 @@ use crate::codec::Direction; use crate::codec::{cassandra::CassandraCodecBuilder, CodecBuilder}; use crate::server::TcpCodecListener; use crate::sources::Sources; -use crate::tlsls::{TlsAcceptor, TlsAcceptorConfig}; +use crate::tls::{TlsAcceptor, TlsAcceptorConfig}; use crate::transforms::chain::TransformChainBuilder; use anyhow::Result; use serde::Deserialize; diff --git a/shotover/src/sources/kafka.rs b/shotover/src/sources/kafka.rs index 398754219..7bc9117d6 100644 --- a/shotover/src/sources/kafka.rs +++ b/shotover/src/sources/kafka.rs @@ -1,7 +1,7 @@ use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction}; use crate::server::TcpCodecListener; use crate::sources::Sources; -use crate::tlsls::{TlsAcceptor, TlsAcceptorConfig}; +use crate::tls::{TlsAcceptor, TlsAcceptorConfig}; use crate::transforms::chain::TransformChainBuilder; use anyhow::Result; use serde::Deserialize; diff --git a/shotover/src/sources/redis.rs b/shotover/src/sources/redis.rs index 386d7b91c..a2b3f29f9 100644 --- a/shotover/src/sources/redis.rs +++ b/shotover/src/sources/redis.rs @@ -1,7 +1,7 @@ use crate::codec::{redis::RedisCodecBuilder, CodecBuilder, Direction}; use crate::server::TcpCodecListener; use crate::sources::Sources; -use crate::tlsls::{TlsAcceptor, TlsAcceptorConfig}; +use crate::tls::{TlsAcceptor, TlsAcceptorConfig}; use crate::transforms::chain::TransformChainBuilder; use anyhow::Result; use serde::Deserialize; diff --git a/shotover/src/tls.rs b/shotover/src/tls.rs index d9aa849f5..2efe9a2eb 100644 --- a/shotover/src/tls.rs +++ b/shotover/src/tls.rs @@ -1,17 +1,23 @@ use crate::tcp; -use anyhow::{anyhow, Error, Result}; -use openssl::ssl::{ErrorCode, Ssl}; -use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use anyhow::{anyhow, bail, Context, Error, Result}; +use rustls::client::{InvalidDnsNameError, ServerCertVerified, ServerCertVerifier, WebPkiVerifier}; +use rustls::server::{AllowAnyAuthenticatedClient, NoClientAuth}; +use rustls::{ + Certificate, CertificateError, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName, +}; +use rustls_pemfile::{certs, Item}; use serde::{Deserialize, Serialize}; -use std::fmt::Write; +use std::fs::File; +use std::io::{BufReader, ErrorKind}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::path::Path; -use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio_openssl::SslStream; +use tokio_rustls::client::TlsStream as TlsStreamClient; +use tokio_rustls::server::TlsStream as TlsStreamServer; +use tokio_rustls::{TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector}; +use webpki::TrustAnchor; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct TlsAcceptorConfig { @@ -25,7 +31,7 @@ pub struct TlsAcceptorConfig { #[derive(Clone)] pub struct TlsAcceptor { - acceptor: Arc, + acceptor: RustlsAcceptor, } pub enum AcceptError { @@ -35,65 +41,73 @@ pub enum AcceptError { Failure(Error), } -pub fn check_file_field(field_name: &str, file_path: &str) -> Result<()> { - if Path::new(file_path).exists() { - Ok(()) - } else { - Err(anyhow!( - "configured {field_name} does not exist '{file_path}'" - )) - } +fn load_certs(path: &str) -> Result> { + certs(&mut BufReader::new(File::open(path)?)) + .context("Error while parsing PEM") + .map(|certs| certs.into_iter().map(Certificate).collect()) +} + +fn load_keys(path: &str) -> Result> { + rustls_pemfile::read_all(&mut BufReader::new(File::open(path)?)) + .context("Error while parsing PEM") + .map(|keys| { + keys.into_iter() + .filter_map(|item| match item { + Item::RSAKey(x) | Item::PKCS8Key(x) => Some(PrivateKey(x)), + _ => None, + }) + .collect() + }) } impl TlsAcceptor { pub fn new(tls_config: TlsAcceptorConfig) -> Result { - // openssl's errors are really bad so we do our own checks so we can provide reasonable errors - check_file_field("private_key_path", &tls_config.private_key_path)?; - check_file_field("certificate_path", &tls_config.certificate_path)?; - - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()) - .map_err(openssl_stack_error_to_anyhow)?; - - if let Some(path) = tls_config.certificate_authority_path.as_ref() { - check_file_field("certificate_authority_path", path)?; - builder - .set_ca_file(path) - .map_err(openssl_stack_error_to_anyhow)?; - return Err(anyhow!("Client auth is not yet supported in shotover")); - } + let client_cert_verifier = + if let Some(path) = tls_config.certificate_authority_path.as_ref() { + let root_cert_store = load_ca(path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {path} configured at 'certificate_authority_path'" + )) + })?; + AllowAnyAuthenticatedClient::new(root_cert_store).boxed() + } else { + NoClientAuth::boxed() + }; + + let mut keys = load_keys(&tls_config.private_key_path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {} configured at 'private_key_path", + tls_config.private_key_path, + )) + })?; + let certs = load_certs(&tls_config.certificate_path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {} configured at 'certificate_path'", + tls_config.private_key_path, + )) + })?; - builder - .set_private_key_file(tls_config.private_key_path, SslFiletype::PEM) - .map_err(openssl_stack_error_to_anyhow)?; - builder - .set_certificate_chain_file(tls_config.certificate_path) - .map_err(openssl_stack_error_to_anyhow)?; - builder - .check_private_key() - .map_err(openssl_stack_error_to_anyhow)?; + let config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(client_cert_verifier) + .with_single_cert(certs, keys.remove(0))?; Ok(TlsAcceptor { - acceptor: Arc::new(builder.build()), + acceptor: RustlsAcceptor::from(Arc::new(config)), }) } - pub async fn accept(&self, tcp_stream: TcpStream) -> Result, AcceptError> { - let ssl = Ssl::new(self.acceptor.context()) - .map_err(|e| AcceptError::Failure(openssl_stack_error_to_anyhow(e)))?; - let mut ssl_stream = SslStream::new(ssl, tcp_stream) - .map_err(|e| AcceptError::Failure(openssl_stack_error_to_anyhow(e)))?; - - Pin::new(&mut ssl_stream).accept().await.map_err(|e| { - // This is the internal logic that results in the "unexpected EOF" error in the ssl::error::Error display impl - if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() { - AcceptError::Disconnected - } else { - AcceptError::Failure( - openssl_ssl_error_to_anyhow(e).context("Failed to accept TLS connection"), - ) - } - })?; - Ok(ssl_stream) + pub async fn accept( + &self, + tcp_stream: TcpStream, + ) -> Result, AcceptError> { + self.acceptor + .accept(tcp_stream) + .await + .map_err(|err| match err.kind() { + ErrorKind::UnexpectedEof => AcceptError::Disconnected, + _ => AcceptError::Failure(anyhow!(err).context("Failed to accept TLS connection")), + }) } } @@ -109,41 +123,89 @@ pub struct TlsConnectorConfig { pub verify_hostname: bool, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TlsConnector { - connector: Arc, - verify_hostname: bool, + connector: RustlsConnector, +} + +fn load_ca(path: &str) -> Result { + let mut root_cert_store = RootCertStore::empty(); + + let mut pem = BufReader::new(File::open(path)?); + let certs = rustls_pemfile::certs(&mut pem).context("Error while parsing PEM")?; + let trust_anchors = certs.iter().map(|cert| { + let ta = TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }); + root_cert_store.add_server_trust_anchors(trust_anchors); + Ok(root_cert_store) } impl TlsConnector { pub fn new(tls_config: TlsConnectorConfig) -> Result { - check_file_field( - "certificate_authority_path", - &tls_config.certificate_authority_path, - )?; - let mut builder = - SslConnector::builder(SslMethod::tls()).map_err(openssl_stack_error_to_anyhow)?; - builder - .set_ca_file(tls_config.certificate_authority_path) - .map_err(openssl_stack_error_to_anyhow)?; - - if let Some(private_key_path) = tls_config.private_key_path { - check_file_field("private_key_path", &private_key_path)?; - builder - .set_private_key_file(private_key_path, SslFiletype::PEM) - .map_err(openssl_stack_error_to_anyhow)?; - } + let root_cert_store = load_ca(&tls_config.certificate_authority_path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {} configured at 'certificate_authority_path'", + tls_config.certificate_authority_path, + )) + })?; - if let Some(certificate_path) = tls_config.certificate_path { - check_file_field("certificate_path", &certificate_path)?; - builder - .set_certificate_chain_file(certificate_path) - .map_err(openssl_stack_error_to_anyhow)?; - } + let keys = tls_config + .private_key_path + .as_ref() + .map(|path| { + load_keys(path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {path} configured at 'private_key_path", + )) + }) + }) + .transpose()?; + let certs = tls_config + .certificate_path + .as_ref() + .map(|path| { + load_certs(path).map_err(|err| { + anyhow!(err).context(format!( + "Failed to read file {path} configured at 'certificate_path'", + )) + }) + }) + .transpose()?; + + let config_builder = rustls::ClientConfig::builder().with_safe_defaults(); + let config = match (keys, certs, tls_config.verify_hostname) { + (Some(mut keys), Some(certs), true) => config_builder + .with_root_certificates(root_cert_store) + .with_single_cert(certs, keys.remove(0))?, + (Some(mut keys), Some(certs), false) => config_builder + .with_custom_certificate_verifier(Arc::new(SkipVerifyHostName::new( + root_cert_store, + ))) + .with_single_cert(certs, keys.remove(0))?, + (None, None, true) => config_builder + .with_root_certificates(root_cert_store) + .with_no_client_auth(), + (None, None, false) => config_builder + .with_custom_certificate_verifier(Arc::new(SkipVerifyHostName::new( + root_cert_store, + ))) + .with_no_client_auth(), + + (Some(_), None, _) => { + bail!("private_key_path was specified but certificate_path was not") + } + (None, Some(_), _) => { + bail!("certificate_path was specified but private_key_path was not") + } + }; Ok(TlsConnector { - connector: Arc::new(builder.build()), - verify_hostname: tls_config.verify_hostname, + connector: RustlsConnector::from(Arc::new(config)), }) } @@ -151,76 +213,58 @@ impl TlsConnector { &self, connect_timeout: Duration, address: A, - ) -> Result> { - let ssl = self - .connector - .configure() - .map_err(openssl_stack_error_to_anyhow)? - .verify_hostname(self.verify_hostname) - .into_ssl(&address.to_hostname()) - .map_err(openssl_stack_error_to_anyhow)?; - + ) -> Result> { + let servername = address.to_servername()?; let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?; - let mut ssl_stream = - SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?; - Pin::new(&mut ssl_stream).connect().await.map_err(|e| { - openssl_ssl_error_to_anyhow(e) - .context("Failed to establish TLS connection to destination") - })?; - - Ok(ssl_stream) + self.connector + .connect(servername, tcp_stream) + .await + .map_err(|e| anyhow!("{e:#?}")) + .context("Failed to establish TLS connection to destination") } } -// Always use these openssl_* conversion methods instead of directly converting to anyhow - -fn openssl_ssl_error_to_anyhow(error: openssl::ssl::Error) -> anyhow::Error { - if let Some(stack) = error.ssl_error() { - openssl_stack_error_to_anyhow(stack.clone()) - } else { - anyhow!("{error}") - } +pub struct SkipVerifyHostName { + verifier: WebPkiVerifier, } -fn openssl_stack_error_to_anyhow(error: openssl::error::ErrorStack) -> anyhow::Error { - let mut anyhow_stack: Option = None; - for inner in error.errors() { - let anyhow_error = openssl_error_to_anyhow(inner.clone()); - anyhow_stack = Some(match anyhow_stack { - Some(anyhow) => anyhow.context(anyhow_error), - None => anyhow_error, - }); - } - match anyhow_stack { - Some(anyhow_stack) => anyhow_stack, - None => anyhow!("{error}"), +impl SkipVerifyHostName { + pub fn new(roots: RootCertStore) -> Self { + SkipVerifyHostName { + verifier: WebPkiVerifier::new(roots, None), + } } } -fn openssl_error_to_anyhow(error: openssl::error::Error) -> anyhow::Error { - let mut fmt = String::new(); - write!(fmt, "error 0x{:08X} ", error.code()).unwrap(); - match error.reason() { - Some(r) => write!(fmt, "'{}", r).unwrap(), - None => write!(fmt, "'Unknown'").unwrap(), - } - if let Some(data) = error.data() { - write!(fmt, ": {}' ", data).unwrap(); - } else { - write!(fmt, "' ").unwrap(); - } - write!(fmt, "occurred in ").unwrap(); - match error.function() { - Some(f) => write!(fmt, "function '{}' ", f).unwrap(), - None => write!(fmt, "function 'Unknown' ").unwrap(), - } - write!(fmt, "in file '{}:{}' ", error.file(), error.line()).unwrap(); - match error.library() { - Some(l) => write!(fmt, "in library '{}'", l).unwrap(), - None => write!(fmt, "in library 'Unknown'").unwrap(), +// This recreates the verify_hostname(false) functionality from openssl. +// This adds an opening for MitM attacks but we provide this functionality because there are some +// circumstances where providing a cert per instance in a cluster is difficult and this allows at least some security by sharing a single cert between all instances. +// Note that the SAN dnsname wildcards (e.g. *foo.com) wouldnt help here because we need to refer to destinations by ip address and there is no such wildcard functionality for ip addresses. +impl ServerCertVerifier for SkipVerifyHostName { + fn verify_server_cert( + &self, + end_entity: &Certificate, + intermediates: &[Certificate], + server_name: &ServerName, + scts: &mut dyn Iterator, + ocsp_response: &[u8], + now: std::time::SystemTime, + ) -> std::result::Result { + match self.verifier.verify_server_cert( + end_entity, + intermediates, + server_name, + scts, + ocsp_response, + now, + ) { + Ok(result) => Ok(result), + Err(rustls::Error::InvalidCertificate(CertificateError::NotValidForName)) => { + Ok(ServerCertVerified::assertion()) + } + Err(err) => Err(err), + } } - - anyhow!(fmt) } /// A trait object can only consist of one trait + special language traits like Send/Sync etc @@ -228,12 +272,14 @@ fn openssl_error_to_anyhow(error: openssl::error::Error) -> anyhow::Error { pub trait AsyncStream: AsyncRead + AsyncWrite {} /// We need to tell rust that these types implement AsyncStream even though they already implement AsyncRead and AsyncWrite -impl AsyncStream for tokio_openssl::SslStream {} +impl AsyncStream for TlsStreamClient {} +impl AsyncStream for TlsStreamServer {} impl AsyncStream for TcpStream {} /// Allows retrieving the hostname from any ToSocketAddrs type pub trait ToHostname { fn to_hostname(&self) -> String; + fn to_servername(&self) -> Result; } /// Implement for all reference types @@ -241,52 +287,79 @@ impl ToHostname for &T { fn to_hostname(&self) -> String { (**self).to_hostname() } + fn to_servername(&self) -> Result { + (**self).to_servername() + } } impl ToHostname for String { fn to_hostname(&self) -> String { self.split(':').next().unwrap_or("").to_owned() } + fn to_servername(&self) -> Result { + ServerName::try_from(self.to_hostname().as_str()) + } } impl ToHostname for &str { fn to_hostname(&self) -> String { self.split(':').next().unwrap_or("").to_owned() } + fn to_servername(&self) -> Result { + ServerName::try_from(self.split(':').next().unwrap_or("")) + } } impl ToHostname for (&str, u16) { fn to_hostname(&self) -> String { self.0.to_string() } + fn to_servername(&self) -> Result { + ServerName::try_from(self.0) + } } impl ToHostname for (String, u16) { fn to_hostname(&self) -> String { self.0.to_string() } + fn to_servername(&self) -> Result { + ServerName::try_from(self.0.as_str()) + } } impl ToHostname for (IpAddr, u16) { fn to_hostname(&self) -> String { self.0.to_string() } + fn to_servername(&self) -> Result { + Ok(ServerName::IpAddress(self.0)) + } } impl ToHostname for (Ipv4Addr, u16) { fn to_hostname(&self) -> String { self.0.to_string() } + fn to_servername(&self) -> Result { + Ok(ServerName::IpAddress(IpAddr::V4(self.0))) + } } impl ToHostname for (Ipv6Addr, u16) { fn to_hostname(&self) -> String { self.0.to_string() } + fn to_servername(&self) -> Result { + Ok(ServerName::IpAddress(IpAddr::V6(self.0))) + } } impl ToHostname for SocketAddr { fn to_hostname(&self) -> String { self.ip().to_string() } + fn to_servername(&self) -> Result { + Ok(ServerName::IpAddress(self.ip())) + } } diff --git a/shotover/src/tlsls.rs b/shotover/src/tlsls.rs deleted file mode 100644 index 01499e272..000000000 --- a/shotover/src/tlsls.rs +++ /dev/null @@ -1,319 +0,0 @@ -use crate::tcp; -use anyhow::{anyhow, bail, Context, Error, Result}; -use rustls::client::{InvalidDnsNameError, ServerCertVerified, ServerCertVerifier}; -use rustls::{Certificate, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerName}; -use rustls_pemfile::{certs, Item}; -use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::io::{BufReader, ErrorKind}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::sync::Arc; -use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio_rustls::client::TlsStream as TlsStreamClient; -use tokio_rustls::server::TlsStream as TlsStreamServer; -use tokio_rustls::{TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector}; -use webpki::TrustAnchor; - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct TlsAcceptorConfig { - /// Path to the certificate authority in PEM format - pub certificate_authority_path: String, - /// Path to the certificate in PEM format - pub certificate_path: String, - /// Path to the private key in PEM format - pub private_key_path: String, -} - -#[derive(Clone)] -pub struct TlsAcceptor { - acceptor: RustlsAcceptor, -} - -pub enum AcceptError { - /// The client decided it didnt need the connection anymore and politely disconnected before the handshake completed. - /// This can occur during regular use and indicates the connection should be quietly discarded. - Disconnected, - Failure(Error), -} - -fn load_certs(path: &str) -> Result> { - certs(&mut BufReader::new(File::open(path)?)) - .context("Error while parsing PEM") - .map(|certs| certs.into_iter().map(Certificate).collect()) -} - -fn load_keys(path: &str) -> Result> { - rustls_pemfile::read_all(&mut BufReader::new(File::open(path)?)) - .context("Error while parsing PEM") - .map(|keys| { - keys.into_iter() - .filter_map(|item| match item { - Item::RSAKey(x) | Item::PKCS8Key(x) => Some(PrivateKey(x)), - _ => None, - }) - .collect() - }) -} - -impl TlsAcceptor { - pub fn new(tls_config: TlsAcceptorConfig) -> Result { - let mut keys = load_keys(&tls_config.private_key_path).map_err(|err| { - anyhow!(err).context(format!( - "Failed to read file {} configured at 'private_key_path", - tls_config.private_key_path, - )) - })?; - let certs = load_certs(&tls_config.certificate_path).map_err(|err| { - anyhow!(err).context(format!( - "Failed to read file {} configured at 'certificate_path'", - tls_config.private_key_path, - )) - })?; - - let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, keys.remove(0))?; - - Ok(TlsAcceptor { - acceptor: RustlsAcceptor::from(Arc::new(config)), - }) - } - - pub async fn accept( - &self, - tcp_stream: TcpStream, - ) -> Result, AcceptError> { - self.acceptor - .accept(tcp_stream) - .await - .map_err(|err| match err.kind() { - ErrorKind::UnexpectedEof => AcceptError::Disconnected, - _ => AcceptError::Failure(anyhow!(err).context("Failed to accept TLS connection")), - }) - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct TlsConnectorConfig { - /// Path to the certificate authority in PEM format - pub certificate_authority_path: String, - /// Path to the certificate in PEM format - pub certificate_path: Option, - /// Path to the private key in PEM format - pub private_key_path: Option, - /// enable/disable verifying the hostname of the destination's certificate. - pub verify_hostname: bool, -} - -#[derive(Clone)] -pub struct TlsConnector { - connector: RustlsConnector, -} - -fn load_ca(path: &str) -> Result { - let mut root_cert_store = RootCertStore::empty(); - - let mut pem = BufReader::new(File::open(path)?); - let certs = rustls_pemfile::certs(&mut pem).context("Error while parsing PEM")?; - let trust_anchors = certs.iter().map(|cert| { - let ta = TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - Ok(root_cert_store) -} - -impl TlsConnector { - pub fn new(tls_config: TlsConnectorConfig) -> Result { - let root_cert_store = load_ca(&tls_config.certificate_authority_path).map_err(|err| { - anyhow!(err).context(format!( - "Failed to read file {} configured at 'certificate_authority_path'", - tls_config.certificate_authority_path, - )) - })?; - - let keys = tls_config - .private_key_path - .as_ref() - .map(|path| { - load_keys(path).map_err(|err| { - anyhow!(err).context(format!( - "Failed to read file {path} configured at 'private_key_path", - )) - }) - }) - .transpose()?; - let certs = tls_config - .certificate_path - .as_ref() - .map(|path| { - load_certs(path).map_err(|err| { - anyhow!(err).context(format!( - "Failed to read file {path} configured at 'certificate_path'", - )) - }) - }) - .transpose()?; - - let config_builder = rustls::ClientConfig::builder().with_safe_defaults(); - let config = match (keys, certs, tls_config.verify_hostname) { - (Some(mut keys), Some(certs), true) => config_builder - .with_root_certificates(root_cert_store) - .with_single_cert(certs, keys.remove(0))?, - (Some(mut keys), Some(certs), false) => config_builder - .with_custom_certificate_verifier(Arc::new(SkipVerifyHostName)) - .with_single_cert(certs, keys.remove(0))?, - (None, None, true) => config_builder - .with_root_certificates(root_cert_store) - .with_no_client_auth(), - (None, None, false) => config_builder - .with_custom_certificate_verifier(Arc::new(SkipVerifyHostName)) - .with_no_client_auth(), - - (Some(_), None, _) => { - bail!("private_key_path was specified but certificate_path was not") - } - (None, Some(_), _) => { - bail!("certificate_path was specified but private_key_path was not") - } - }; - - Ok(TlsConnector { - connector: RustlsConnector::from(Arc::new(config)), - }) - } - - pub async fn connect( - &self, - connect_timeout: Duration, - address: A, - ) -> Result> { - let servername = address.to_servername()?; - let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?; - self.connector - .connect(servername, tcp_stream) - .await - .map_err(|e| anyhow!("{e:#?}")) - .context("Failed to establish TLS connection to destination") - } -} - -struct SkipVerifyHostName; -impl ServerCertVerifier for SkipVerifyHostName { - fn verify_server_cert( - &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: std::time::SystemTime, - ) -> std::result::Result { - // TLS is added and removed here :) - Ok(ServerCertVerified::assertion()) - } -} - -/// A trait object can only consist of one trait + special language traits like Send/Sync etc -/// So we need to use this trait when creating trait objects that need both AsyncRead and AsyncWrite -pub trait AsyncStream: AsyncRead + AsyncWrite {} - -/// We need to tell rust that these types implement AsyncStream even though they already implement AsyncRead and AsyncWrite -impl AsyncStream for TlsStreamClient {} -impl AsyncStream for TlsStreamServer {} -impl AsyncStream for TcpStream {} - -/// Allows retrieving the hostname from any ToSocketAddrs type -pub trait ToHostname { - fn to_hostname(&self) -> String; - fn to_servername(&self) -> Result; -} - -/// Implement for all reference types -impl ToHostname for &T { - fn to_hostname(&self) -> String { - (**self).to_hostname() - } - fn to_servername(&self) -> Result { - (**self).to_servername() - } -} - -impl ToHostname for String { - fn to_hostname(&self) -> String { - self.split(':').next().unwrap_or("").to_owned() - } - fn to_servername(&self) -> Result { - ServerName::try_from(self.to_hostname().as_str()) - } -} - -impl ToHostname for &str { - fn to_hostname(&self) -> String { - self.split(':').next().unwrap_or("").to_owned() - } - fn to_servername(&self) -> Result { - ServerName::try_from(self.split(':').next().unwrap_or("")) - } -} - -impl ToHostname for (&str, u16) { - fn to_hostname(&self) -> String { - self.0.to_string() - } - fn to_servername(&self) -> Result { - ServerName::try_from(self.0) - } -} - -impl ToHostname for (String, u16) { - fn to_hostname(&self) -> String { - self.0.to_string() - } - fn to_servername(&self) -> Result { - ServerName::try_from(self.0.as_str()) - } -} - -impl ToHostname for (IpAddr, u16) { - fn to_hostname(&self) -> String { - self.0.to_string() - } - fn to_servername(&self) -> Result { - Ok(ServerName::IpAddress(self.0)) - } -} - -impl ToHostname for (Ipv4Addr, u16) { - fn to_hostname(&self) -> String { - self.0.to_string() - } - fn to_servername(&self) -> Result { - Ok(ServerName::IpAddress(IpAddr::V4(self.0))) - } -} - -impl ToHostname for (Ipv6Addr, u16) { - fn to_hostname(&self) -> String { - self.0.to_string() - } - fn to_servername(&self) -> Result { - Ok(ServerName::IpAddress(IpAddr::V6(self.0))) - } -} - -impl ToHostname for SocketAddr { - fn to_hostname(&self) -> String { - self.ip().to_string() - } - fn to_servername(&self) -> Result { - Ok(ServerName::IpAddress(self.ip())) - } -} diff --git a/shotover/src/transforms/cassandra/connection.rs b/shotover/src/transforms/cassandra/connection.rs index 8940841b0..03a12406e 100644 --- a/shotover/src/transforms/cassandra/connection.rs +++ b/shotover/src/transforms/cassandra/connection.rs @@ -4,7 +4,7 @@ use crate::frame::cassandra::CassandraMetadata; use crate::frame::{CassandraFrame, Frame}; use crate::message::{Message, Metadata}; use crate::tcp; -use crate::tlsls::{TlsConnector, ToHostname}; +use crate::tls::{TlsConnector, ToHostname}; use crate::transforms::Messages; use anyhow::{anyhow, Result}; use cassandra_protocol::frame::{Opcode, Version}; diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index f6fd4110c..20e7c5031 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -4,7 +4,7 @@ use crate::error::ChainResponse; use crate::frame::cassandra::{CassandraMetadata, Tracing}; use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::{Message, Messages, Metadata}; -use crate::tlsls::{TlsConnector, TlsConnectorConfig}; +use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{anyhow, Result}; diff --git a/shotover/src/transforms/cassandra/sink_cluster/node.rs b/shotover/src/transforms/cassandra/sink_cluster/node.rs index 5e2379c22..8008d9b86 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/node.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/node.rs @@ -2,7 +2,7 @@ use crate::codec::cassandra::CassandraCodecBuilder; use crate::codec::{CodecBuilder, Direction}; use crate::frame::Frame; use crate::message::{Message, Messages}; -use crate::tlsls::{TlsConnector, ToHostname}; +use crate::tls::{TlsConnector, ToHostname}; use crate::transforms::cassandra::connection::CassandraConnection; use anyhow::{anyhow, Result}; use cassandra_protocol::frame::Version; diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index 9d028025f..4c78eac4f 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -3,7 +3,7 @@ use crate::codec::{cassandra::CassandraCodecBuilder, CodecBuilder, Direction}; use crate::error::ChainResponse; use crate::frame::cassandra::CassandraMetadata; use crate::message::{Messages, Metadata}; -use crate::tlsls::{TlsConnector, TlsConnectorConfig}; +use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::Response; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; use anyhow::{anyhow, Result}; diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 4970750f6..c795f0a35 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -3,7 +3,7 @@ use crate::codec::{CodecBuilder, Direction}; use crate::error::ChainResponse; use crate::frame::{Frame, RedisFrame}; use crate::message::Message; -use crate::tlsls::TlsConnectorConfig; +use crate::tls::TlsConnectorConfig; use crate::transforms::redis::RedisError; use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 9257b7420..a70d90563 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -5,7 +5,7 @@ use crate::codec::{ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::tcp; -use crate::tlsls::{AsyncStream, TlsConnector, TlsConnectorConfig}; +use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; use crate::transforms::{ ChainResponse, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, }; diff --git a/shotover/src/transforms/util/cluster_connection_pool.rs b/shotover/src/transforms/util/cluster_connection_pool.rs index bc315e27a..85a960986 100644 --- a/shotover/src/transforms/util/cluster_connection_pool.rs +++ b/shotover/src/transforms/util/cluster_connection_pool.rs @@ -1,7 +1,7 @@ use super::Response; use crate::codec::{CodecBuilder, DecoderHalf, EncoderHalf}; use crate::tcp; -use crate::tlsls::{TlsConnector, TlsConnectorConfig}; +use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::util::{ConnectionError, Request}; use anyhow::{anyhow, Result}; use async_trait::async_trait; diff --git a/test-helpers/src/connection/redis_connection.rs b/test-helpers/src/connection/redis_connection.rs index 57e06ea0d..b44644c1c 100644 --- a/test-helpers/src/connection/redis_connection.rs +++ b/test-helpers/src/connection/redis_connection.rs @@ -55,8 +55,6 @@ pub async fn new_async_tls(port: u16) -> redis::aio::Connection { .configure() .unwrap() .verify_hostname(false) - // really upstream should deal with this for us but for now we can easily just disable it ourselves https://github.com/sfackler/rust-openssl/issues/1860 - .use_server_name_indication(false) .into_ssl("127.0.0.1") .unwrap();