Skip to content

Commit

Permalink
fix: remove indirections in protocols:tls:TlsStream
Browse files Browse the repository at this point in the history
  • Loading branch information
hargut committed Sep 11, 2024
1 parent b84ad21 commit 90a823a
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 523 deletions.
2 changes: 1 addition & 1 deletion pingora-core/src/protocols/tls/boringssl_openssl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
use pingora_error::{Error, ErrorType::*, OrErr, Result};

use crate::protocols::tls::boringssl_openssl::TlsStream;
use crate::protocols::tls::TlsStream;
use crate::protocols::IO;
use crate::tls::ssl::ConnectConfiguration;

Expand Down
153 changes: 0 additions & 153 deletions pingora-core/src/protocols/tls/boringssl_openssl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,159 +13,6 @@
// limitations under the License.

//! BoringSSL & OpenSSL TLS specific implementation
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};

use pingora_error::ErrorType::TLSHandshakeFailure;
use pingora_error::{OrErr, Result};

use crate::protocols::tls::boringssl_openssl::stream::InnerStream;
use crate::protocols::tls::SslDigest;
use crate::protocols::{Ssl, UniqueID, ALPN};
use crate::tls::hash::MessageDigest;
use crate::tls::ssl;
use crate::tls::ssl::SslRef;
use crate::utils::tls::boringssl_openssl::{get_x509_organization, get_x509_serial};

use super::TlsStream;

pub mod client;
pub mod server;
pub(super) mod stream;

impl<T> TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send,
{
/// Create a new TLS connection from the given `stream`
///
/// The caller needs to perform [`Self::connect()`] or [`Self::accept()`] to perform TLS
/// handshake after.
pub fn new(ssl: ssl::Ssl, stream: T) -> Result<Self> {
let tls = InnerStream::new(ssl, stream)
.explain_err(TLSHandshakeFailure, |e| format!("tls stream error: {e}"))?;
Ok(TlsStream {
tls,
digest: None,
timing: Default::default(),
})
}
}

impl<T> AsyncRead for TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Self::clear_error(&self);
Pin::new(&mut self.tls.0).poll_read(cx, buf)
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> TlsStream<T> {
#[inline]
fn clear_error(&self) {
InnerStream::<T>::clear_error()
}
}

impl<T> AsyncWrite for TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Self::clear_error(&self);
Pin::new(&mut self.tls.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Self::clear_error(&self);
Pin::new(&mut self.tls.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Self::clear_error(&self);
Pin::new(&mut self.tls.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Self::clear_error(&self);
Pin::new(&mut self.tls.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
true
}
}

impl<T> UniqueID for TlsStream<T>
where
T: UniqueID,
{
fn id(&self) -> i32 {
self.tls.0.get_ref().id()
}
}

impl<T> Ssl for TlsStream<T> {
fn get_ssl(&self) -> Option<&ssl::SslRef> {
Some(self.tls.0.ssl())
}

fn get_ssl_digest(&self) -> Option<Arc<SslDigest>> {
self.ssl_digest()
}

fn selected_alpn_proto(&self) -> Option<ALPN> {
let ssl = self.tls.0.ssl();
ALPN::from_wire_selected(ssl.selected_alpn_protocol()?)
}
}

impl SslDigest {
pub fn from_ssl(ssl: &SslRef) -> Self {
let cipher = match ssl.current_cipher() {
Some(c) => c.name(),
None => "",
};

let (cert_digest, org, sn) = match ssl.peer_certificate() {
Some(cert) => {
let cert_digest = match cert.digest(MessageDigest::sha256()) {
Ok(c) => c.as_ref().to_vec(),
Err(_) => Vec::new(),
};
(
cert_digest,
get_x509_organization(&cert),
get_x509_serial(&cert).ok(),
)
}
None => (Vec::new(), None, None),
};

SslDigest {
cipher,
version: ssl.version_str(),
organization: org,
serial_number: sn,
cert_digest,
}
}
}
9 changes: 5 additions & 4 deletions pingora-core/src/protocols/tls/boringssl_openssl/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@

//! BoringSSL & OpenSSL TLS server specific implementation
use crate::protocols::Ssl;
use std::pin::Pin;

use async_trait::async_trait;
use pingora_error::ErrorType::{TLSHandshakeFailure, TLSWantX509Lookup};
use pingora_error::{OrErr, Result};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::protocols::tls::boringssl_openssl::TlsStream;
use crate::protocols::tls::server::{ResumableAccept, TlsAcceptCallbacks};
use crate::protocols::{Ssl, IO};
use crate::protocols::tls::TlsStream;
use crate::protocols::IO;
use crate::tls::ext;
use crate::tls::ext::ssl_from_acceptor;
use crate::tls::ssl::SslAcceptor;
Expand Down Expand Up @@ -57,7 +58,7 @@ impl<S: AsyncRead + AsyncWrite + Send + Unpin> ResumableAccept for TlsStream<S>
}

fn prepare_tls_stream<S: IO>(acceptor: &SslAcceptor, io: S) -> Result<TlsStream<S>> {
let ssl = ssl_from_acceptor(&acceptor)
let ssl = ssl_from_acceptor(acceptor)
.explain_err(TLSHandshakeFailure, |e| format!("ssl_acceptor error: {e}"))?;
TlsStream::new(ssl, io).explain_err(TLSHandshakeFailure, |e| format!("tls stream error: {e}"))
}
Expand All @@ -82,7 +83,7 @@ pub(crate) async fn handshake_with_callback<S: IO>(
let done = Pin::new(&mut tls_stream).start_accept().await?;
if !done {
// safety: we do hold a mut ref of tls_stream
let ssl_mut = unsafe { ext::ssl_mut(tls_stream.0.ssl()) };
let ssl_mut = unsafe { ext::ssl_mut(tls_stream.stream.ssl()) };
callbacks.certificate_callback(ssl_mut).await;
Pin::new(&mut tls_stream)
.resume_accept()
Expand Down
Loading

0 comments on commit 90a823a

Please sign in to comment.