Skip to content

Commit

Permalink
unify+improve errors: Limit, ASN & TLS
Browse files Browse the repository at this point in the history
Closes  #318
  • Loading branch information
GlenDC committed Sep 17, 2024
1 parent 4a016b0 commit 5b9fd17
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 173 deletions.
14 changes: 3 additions & 11 deletions rama-core/src/layer/limit/policy/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,11 @@ where
}
}

/// The error that indicates the request is aborted,
/// because the concurrent request limit is reached.
#[derive(Debug)]
pub struct LimitReached;

impl std::fmt::Display for LimitReached {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("LimitReached")
}
rama_utils::macros::error::static_str_error! {
#[doc = "request aborted due to exhausted concurrency limit"]
pub struct LimitReached;
}

impl std::error::Error for LimitReached {}

/// The tracker trait that can be implemented to provide custom concurrent request tracking.
///
/// By default [`ConcurrentCounter`] is provided, but in case you need multi-instance tracking,
Expand Down
15 changes: 3 additions & 12 deletions rama-net/src/asn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,7 @@ impl<'de> Deserialize<'de> for Asn {
}
}

#[derive(Debug, Clone)]
#[non_exhaustive]
/// Error to indicate an invalid ASN for any reason,
/// most typically being because it is within the reserved space.
pub struct InvalidAsn;

impl fmt::Display for InvalidAsn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "invalid ASN")
}
rama_utils::macros::error::static_str_error! {
#[doc = "invalid ASN (e.g. within reserved space)"]
pub struct InvalidAsn;
}

impl std::error::Error for InvalidAsn {}
2 changes: 1 addition & 1 deletion rama-tls/src/boring/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use config::ServerConfig;

mod service;
#[doc(inline)]
pub use service::{TlsAcceptorError, TlsAcceptorService};
pub use service::TlsAcceptorService;

mod layer;
#[doc(inline)]
Expand Down
89 changes: 19 additions & 70 deletions rama-tls/src/boring/server/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use crate::{
};
use parking_lot::Mutex;
use rama_core::{
error::{ErrorContext, ErrorExt, OpaqueError},
error::{BoxError, ErrorContext, ErrorExt, OpaqueError},
Context, Service,
};
use rama_net::stream::Stream;
use rama_utils::macros::define_inner_service_accessors;
use std::{fmt, sync::Arc};
use std::sync::Arc;

/// A [`Service`] which accepts TLS connections and delegates the underlying transport
/// stream to the given service.
Expand Down Expand Up @@ -64,45 +64,39 @@ impl<T, S, IO> Service<T, IO> for TlsAcceptorService<S>
where
T: Send + Sync + 'static,
IO: Stream + Unpin + 'static,
S: Service<T, SslStream<IO>>,
S: Service<T, SslStream<IO>, Error: Into<BoxError>>,
{
type Response = S::Response;
type Error = TlsAcceptorError<S::Error>;
type Error = BoxError;

async fn serve(&self, mut ctx: Context<T>, stream: IO) -> Result<Self::Response, Self::Error> {
// let acceptor = TlsAcceptor::from(self.config.clone());

let mut acceptor_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())
.context("create boring ssl acceptor")
.map_err(TlsAcceptorError::Accept)?;
.context("create boring ssl acceptor")?;

acceptor_builder.set_grease_enabled(true);
acceptor_builder
.set_default_verify_paths()
.context("build boring ssl acceptor: set default verify paths")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: set default verify paths")?;

for (i, ca_cert) in self.config.ca_cert_chain.iter().enumerate() {
if i == 0 {
acceptor_builder
.set_certificate(ca_cert.as_ref())
.context("build boring ssl acceptor: set Leaf CA certificate (x509)")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: set Leaf CA certificate (x509)")?;
} else {
acceptor_builder
.add_extra_chain_cert(ca_cert.clone())
.context("build boring ssl acceptor: add extra chain certificate (x509)")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: add extra chain certificate (x509)")?;
}
}
acceptor_builder
.set_private_key(self.config.private_key.as_ref())
.context("build boring ssl acceptor: set private key")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: set private key")?;
acceptor_builder
.check_private_key()
.context("build boring ssl acceptor: check private key")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: check private key")?;

let mut maybe_client_hello = if self.store_client_hello {
let maybe_client_hello = Arc::new(Mutex::new(None));
Expand All @@ -127,13 +121,11 @@ where
let mut buf = vec![];
for alpn in &self.config.alpn_protocols {
alpn.encode_wire_format(&mut buf)
.context("build boring ssl acceptor: encode alpn")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: encode alpn")?;
}
acceptor_builder
.set_alpn_protos(&buf[..])
.context("build boring ssl acceptor: set alpn")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: set alpn")?;
}

if let Some(keylog_filename) = &self.config.keylog_filename {
Expand All @@ -142,8 +134,7 @@ where
.append(true)
.create(true)
.open(keylog_filename)
.context("build boring ssl acceptor: set keylog: open file")
.map_err(TlsAcceptorError::Accept)?;
.context("build boring ssl acceptor: set keylog: open file")?;
acceptor_builder.set_keylog_callback(move |_, line| {
use std::io::Write;
let line = format!("{}\n", line);
Expand All @@ -160,8 +151,7 @@ where
Some(err) => OpaqueError::from_display(err.to_string())
.context("boring ssl acceptor: accept"),
None => OpaqueError::from_display("boring ssl acceptor: accept"),
})
.map_err(TlsAcceptorError::Accept)?;
})?;

let secure_transport = maybe_client_hello
.take()
Expand All @@ -170,51 +160,10 @@ where
.unwrap_or_default();
ctx.insert(secure_transport);

self.inner
.serve(ctx, stream)
.await
.map_err(TlsAcceptorError::Service)
}
}

/// Errors that can happen when using [`TlsAcceptorService`].
pub enum TlsAcceptorError<E> {
/// An error occurred while accepting a TLS connection.
Accept(OpaqueError),
/// An error occurred while serving the underlying transport stream
/// using the inner service.
Service(E),
}

impl<E: fmt::Debug> fmt::Debug for TlsAcceptorError<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Accept(err) => write!(f, "TlsAcceptorError::Accept({err:?})"),
Self::Service(err) => write!(f, "TlsAcceptorError::Service({err:?})"),
}
}
}

impl<E> std::fmt::Display for TlsAcceptorError<E>
where
E: std::fmt::Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TlsAcceptorError::Accept(e) => write!(f, "accept error: {}", e),
TlsAcceptorError::Service(e) => write!(f, "service error: {}", e),
}
}
}

impl<E> std::error::Error for TlsAcceptorError<E>
where
E: std::fmt::Debug + std::fmt::Display,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
TlsAcceptorError::Accept(e) => Some(e),
TlsAcceptorError::Service(_) => None,
}
self.inner.serve(ctx, stream).await.map_err(|err| {
OpaqueError::from_boxed(err.into())
.context("rustls acceptor: service error")
.into_boxed()
})
}
}
2 changes: 1 addition & 1 deletion rama-tls/src/rustls/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

mod service;
#[doc(inline)]
pub use service::{TlsAcceptorError, TlsAcceptorService};
pub use service::TlsAcceptorService;

mod client_config;
#[doc(inline)]
Expand Down
Loading

0 comments on commit 5b9fd17

Please sign in to comment.