Skip to content

Commit

Permalink
Map upstream tls errors into a better format (#891)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Nov 1, 2022
1 parent a11207f commit 0b6f7c3
Showing 1 changed file with 106 additions and 29 deletions.
135 changes: 106 additions & 29 deletions shotover-proxy/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use anyhow::{anyhow, Result};
use openssl::ssl::Ssl;
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
Expand Down Expand Up @@ -44,25 +45,34 @@ impl TlsAcceptor {
check_file_field("private_key_path", &tls_config.private_key_path)?;
check_file_field("certificate_path", &tls_config.certificate_path)?;

let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
builder.set_ca_file(tls_config.certificate_authority_path)?;
builder.set_private_key_file(tls_config.private_key_path, SslFiletype::PEM)?;
builder.set_certificate_chain_file(tls_config.certificate_path)?;
builder.check_private_key()?;
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())
.map_err(openssl_stack_error_to_anyhow)?;
builder
.set_ca_file(tls_config.certificate_authority_path)
.map_err(openssl_stack_error_to_anyhow)?;
builder
.set_private_key_file(tls_config.private_key_path, SslFiletype::PEM)
.map_err(openssl_stack_error_to_anyhow)?;
builder
.set_certificate_chain_file(tls_config.certificate_path)
.map_err(openssl_stack_error_to_anyhow)?;
builder
.check_private_key()
.map_err(openssl_stack_error_to_anyhow)?;

Ok(TlsAcceptor {
acceptor: Arc::new(builder.build()),
})
}

pub async fn accept(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
let ssl = Ssl::new(self.acceptor.context())?;
let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
let ssl = Ssl::new(self.acceptor.context()).map_err(openssl_stack_error_to_anyhow)?;
let mut ssl_stream =
SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?;

Pin::new(&mut ssl_stream)
.accept()
.await
.map_err(|e| anyhow!(e).context("Failed to accept TLS connection"))?;
Pin::new(&mut ssl_stream).accept().await.map_err(|e| {
openssl_ssl_error_to_anyhow(e).context("Failed to accept TLS connection")
})?;
Ok(ssl_stream)
}
}
Expand All @@ -88,17 +98,24 @@ impl TlsConnector {
"certificate_authority_path",
&tls_config.certificate_authority_path,
)?;
let mut builder = SslConnector::builder(SslMethod::tls())?;
builder.set_ca_file(tls_config.certificate_authority_path)?;
let mut builder =
SslConnector::builder(SslMethod::tls()).map_err(openssl_stack_error_to_anyhow)?;
builder
.set_ca_file(tls_config.certificate_authority_path)
.map_err(openssl_stack_error_to_anyhow)?;

if let Some(private_key_path) = tls_config.private_key_path {
check_file_field("private_key_path", &private_key_path)?;
builder.set_private_key_file(private_key_path, SslFiletype::PEM)?;
builder
.set_private_key_file(private_key_path, SslFiletype::PEM)
.map_err(openssl_stack_error_to_anyhow)?;
}

if let Some(certificate_path) = tls_config.certificate_path {
check_file_field("certificate_path", &certificate_path)?;
builder.set_certificate_chain_file(certificate_path)?;
builder
.set_certificate_chain_file(certificate_path)
.map_err(openssl_stack_error_to_anyhow)?;
}

Ok(TlsConnector {
Expand All @@ -112,32 +129,92 @@ impl TlsConnector {
) -> Result<SslStream<TcpStream>> {
let ssl = self
.connector
.configure()?
.configure()
.map_err(openssl_stack_error_to_anyhow)?
.verify_hostname(false)
.into_ssl("localhost")?;
.into_ssl("localhost")
.map_err(openssl_stack_error_to_anyhow)?;

let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
Pin::new(&mut ssl_stream)
.connect()
.await
.map_err(|e| anyhow!(e).context("Failed to establish TLS connection to destination"))?;
let mut ssl_stream =
SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?;
Pin::new(&mut ssl_stream).connect().await.map_err(|e| {
openssl_ssl_error_to_anyhow(e)
.context("Failed to establish TLS connection to destination")
})?;

Ok(ssl_stream)
}

pub async fn connect(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
let ssl = self.connector.configure()?.into_ssl("localhost")?;

let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
Pin::new(&mut ssl_stream)
.connect()
.await
.map_err(|e| anyhow!(e).context("Failed to establish TLS connection to destination"))?;
let ssl = self
.connector
.configure()
.map_err(openssl_stack_error_to_anyhow)?
.into_ssl("localhost")
.map_err(openssl_stack_error_to_anyhow)?;

let mut ssl_stream =
SslStream::new(ssl, tcp_stream).map_err(openssl_stack_error_to_anyhow)?;
Pin::new(&mut ssl_stream).connect().await.map_err(|e| {
openssl_ssl_error_to_anyhow(e)
.context("Failed to establish TLS connection to destination")
})?;

Ok(ssl_stream)
}
}

// Always use these openssl_* conversion methods instead of directly directly converting to anyhow

fn openssl_ssl_error_to_anyhow(error: openssl::ssl::Error) -> anyhow::Error {
if let Some(stack) = error.ssl_error() {
openssl_stack_error_to_anyhow(stack.clone())
} else {
anyhow!("{error}")
}
}

fn openssl_stack_error_to_anyhow(error: openssl::error::ErrorStack) -> anyhow::Error {
let mut anyhow_stack: Option<anyhow::Error> = None;
for inner in error.errors() {
let anyhow_error = openssl_error_to_anyhow(inner.clone());
anyhow_stack = Some(match anyhow_stack {
Some(anyhow) => anyhow.context(anyhow_error),
None => anyhow_error,
});
}
match anyhow_stack {
Some(anyhow_stack) => anyhow_stack,
None => anyhow!("{error}"),
}
}

fn openssl_error_to_anyhow(error: openssl::error::Error) -> anyhow::Error {
let mut fmt = String::new();
write!(fmt, "error 0x{:08X} ", error.code()).unwrap();
match error.reason() {
Some(r) => write!(fmt, "'{}", r).unwrap(),
None => write!(fmt, "'Unknown'").unwrap(),
}
if let Some(data) = error.data() {
write!(fmt, ": {}' ", data).unwrap();
} else {
write!(fmt, "' ").unwrap();
}
write!(fmt, "occurred in ").unwrap();
match error.function() {
Some(f) => write!(fmt, "function '{}' ", f).unwrap(),
None => write!(fmt, "function 'Unknown' ").unwrap(),
}
write!(fmt, "in file '{}:{}' ", error.file(), error.line()).unwrap();
match error.library() {
Some(l) => write!(fmt, "in library '{}'", l).unwrap(),
None => write!(fmt, "in library 'Unknown'").unwrap(),
}

anyhow!(fmt)
}

/// A trait object can only consist of one trait + special language traits like Send/Sync etc
/// So we need to use this trait when creating trait objects that need both AsyncRead and AsyncWrite
pub trait AsyncStream: AsyncRead + AsyncWrite {}
Expand Down

0 comments on commit 0b6f7c3

Please sign in to comment.