diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index f745ec72..662d1fed 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added -- None. +- Added `CatchPanic` middleware which catches panics and converts them + into `500 Internal Server` responses ## Changed diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 96b70a74..facaf6f4 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -48,12 +48,14 @@ tokio = { version = "1", features = ["full"] } tower = { version = "0.4.10", features = ["buffer", "util", "retry", "make", "timeout"] } tracing-subscriber = "0.3" uuid = { version = "0.8", features = ["v4"] } +serde_json = "1.0" [features] default = [] full = [ "add-extension", "auth", + "catch-panic", "compression-full", "cors", "decompression-full", @@ -73,6 +75,7 @@ full = [ add-extension = [] auth = ["base64"] +catch-panic = ["tracing", "futures-util/std"] cors = [] follow-redirect = ["iri-string", "tower/util"] fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate"] diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index d7084866..0467abd1 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -365,6 +365,19 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { ) -> ServiceBuilder> { self.propagate_request_id(HeaderName::from_static(crate::request_id::X_REQUEST_ID)) } + + /// Catch panics and convert them into `500 Internal Server` responses. + /// + /// See [`tower_http::catch_panic`] for more details. + /// + /// [`tower_http::catch_panic`]: crate::catch_panic + #[cfg(feature = "catch-panic")] + #[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] + fn catch_panic( + self, + ) -> ServiceBuilder< + Stack, L>, + >; } impl crate::sealed::Sealed for ServiceBuilder {} @@ -588,4 +601,14 @@ impl ServiceBuilderExt for ServiceBuilder { ) -> ServiceBuilder> { self.layer(crate::request_id::PropagateRequestIdLayer::new(header_name)) } + + #[cfg(feature = "catch-panic")] + #[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] + fn catch_panic( + self, + ) -> ServiceBuilder< + Stack, L>, + > { + self.layer(crate::catch_panic::CatchPanicLayer::new()) + } } diff --git a/tower-http/src/catch_panic.rs b/tower-http/src/catch_panic.rs new file mode 100644 index 00000000..b8e3213e --- /dev/null +++ b/tower-http/src/catch_panic.rs @@ -0,0 +1,401 @@ +//! Convert panics into responses. +//! +//! Note that using panics for error handling is _not_ recommended. Prefer instead to use `Result` +//! whenever possible. +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response, header::HeaderName}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::catch_panic::CatchPanicLayer; +//! use hyper::Body; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! async fn handle(req: Request) -> Result, Infallible> { +//! panic!("something went wrong...") +//! } +//! +//! let mut svc = ServiceBuilder::new() +//! // Catch panics and convert them into responses. +//! .layer(CatchPanicLayer::new()) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::new(Body::empty()); +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.status(), 500); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Using a custom panic handler: +//! +//! ```rust +//! use http::{Request, StatusCode, Response, header::{self, HeaderName}}; +//! use std::{any::Any, convert::Infallible}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::catch_panic::CatchPanicLayer; +//! use hyper::Body; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! async fn handle(req: Request) -> Result, Infallible> { +//! panic!("something went wrong...") +//! } +//! +//! fn handle_panic(err: Box) -> Response { +//! let details = if let Some(s) = err.downcast_ref::() { +//! s.clone() +//! } else if let Some(s) = err.downcast_ref::<&str>() { +//! s.to_string() +//! } else { +//! "Unknown panic message".to_string() +//! }; +//! +//! let body = serde_json::json!({ +//! "error": { +//! "kind": "panic", +//! "details": details, +//! } +//! }); +//! let body = serde_json::to_string(&body).unwrap(); +//! +//! Response::builder() +//! .status(StatusCode::INTERNAL_SERVER_ERROR) +//! .header(header::CONTENT_TYPE, "application/json") +//! .body(Body::from(body)) +//! .unwrap() +//! } +//! +//! let svc = ServiceBuilder::new() +//! // Use `handle_panic` to create the response. +//! .layer(CatchPanicLayer::custom(handle_panic)) +//! .service_fn(handle); +//! # +//! # Ok(()) +//! # } +//! ``` + +use bytes::Bytes; +use futures_core::ready; +use futures_util::future::{CatchUnwind, FutureExt}; +use http::{HeaderValue, Request, Response, StatusCode}; +use http_body::{combinators::UnsyncBoxBody, Body, Full}; +use pin_project_lite::pin_project; +use std::{ + any::Any, + future::Future, + panic::AssertUnwindSafe, + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +use crate::BoxError; + +/// Layer that applies the [`CatchPanic`] middleware that catches panics and converts them into +/// `500 Internal Server` responses. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy, Default)] +pub struct CatchPanicLayer { + panic_handler: T, +} + +impl CatchPanicLayer { + /// Create a new `CatchPanicLayer` with the default panic handler. + pub fn new() -> Self { + CatchPanicLayer { + panic_handler: DefaultResponseForPanic, + } + } +} + +impl CatchPanicLayer { + /// Create a new `CatchPanicLayer` with a custom panic handler. + pub fn custom(panic_handler: T) -> Self + where + T: ResponseForPanic, + { + Self { panic_handler } + } +} + +impl Layer for CatchPanicLayer +where + T: Clone, +{ + type Service = CatchPanic; + + fn layer(&self, inner: S) -> Self::Service { + CatchPanic { + inner, + panic_handler: self.panic_handler.clone(), + } + } +} + +/// Middleware that catches panics and converts them into `500 Internal Server` responses. +/// +/// See the [module docs](self) for an example. +#[derive(Debug, Clone, Copy)] +pub struct CatchPanic { + inner: S, + panic_handler: T, +} + +impl CatchPanic { + /// Create a new `CatchPanic` with the default panic handler. + pub fn new(inner: S) -> Self { + Self { + inner, + panic_handler: DefaultResponseForPanic, + } + } +} + +impl CatchPanic { + define_inner_service_accessors!(); + + /// Create a new `CatchPanic` with a custom panic handler. + pub fn custom(inner: S, panic_handler: T) -> Self + where + T: ResponseForPanic, + { + Self { + inner, + panic_handler, + } + } +} + +impl Service> for CatchPanic +where + S: Service, Response = Response>, + ResBody: Body + Send + 'static, + ResBody::Error: Into, + T: ResponseForPanic + Clone, + T::ResponseBody: Body + Send + 'static, + ::Error: Into, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req))) { + Ok(future) => ResponseFuture { + kind: Kind::Future { + future: AssertUnwindSafe(future).catch_unwind(), + panic_handler: Some(self.panic_handler.clone()), + }, + }, + Err(panic_err) => ResponseFuture { + kind: Kind::Panicked { + panic_err: Some(panic_err), + panic_handler: Some(self.panic_handler.clone()), + }, + }, + } + } +} + +pin_project! { + /// Response future for [`CatchPanic`]. + pub struct ResponseFuture { + #[pin] + kind: Kind, + } +} + +pin_project! { + #[project = KindProj] + enum Kind { + Panicked { + panic_err: Option>, + panic_handler: Option, + }, + Future { + #[pin] + future: CatchUnwind>, + panic_handler: Option, + } + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + ResBody: Body + Send + 'static, + ResBody::Error: Into, + T: ResponseForPanic, + T::ResponseBody: Body + Send + 'static, + ::Error: Into, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Panicked { + panic_err, + panic_handler, + } => { + let panic_handler = panic_handler + .take() + .expect("future polled after completion"); + let panic_err = panic_err.take().expect("future polled after completion"); + Poll::Ready(Ok(response_for_panic(panic_handler, panic_err))) + } + KindProj::Future { + future, + panic_handler, + } => match ready!(future.poll(cx)) { + Ok(Ok(res)) => { + Poll::Ready(Ok(res.map(|body| body.map_err(Into::into).boxed_unsync()))) + } + Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)), + Err(panic_err) => Poll::Ready(Ok(response_for_panic( + panic_handler + .take() + .expect("future polled after completion"), + panic_err, + ))), + }, + } + } +} + +fn response_for_panic( + mut panic_handler: T, + err: Box, +) -> Response> +where + T: ResponseForPanic, + T::ResponseBody: Body + Send + 'static, + ::Error: Into, +{ + panic_handler + .response_for_panic(err) + .map(|body| body.map_err(Into::into).boxed_unsync()) +} + +/// Trait for creating responses from panics. +pub trait ResponseForPanic: Clone { + /// The body type used for responses to panics. + type ResponseBody; + + /// Create a response from the panic error. + fn response_for_panic( + &mut self, + err: Box, + ) -> Response; +} + +impl ResponseForPanic for F +where + F: FnMut(Box) -> Response + Clone, +{ + type ResponseBody = B; + + fn response_for_panic( + &mut self, + err: Box, + ) -> Response { + self(err) + } +} + +/// The default `ResponseForPanic` used by `CatchPanic`. +/// +/// It will log the panic message and return a `500 Internal Server` error response with an empty +/// body. +#[derive(Debug, Default, Clone, Copy)] +#[non_exhaustive] +pub struct DefaultResponseForPanic; + +impl ResponseForPanic for DefaultResponseForPanic { + type ResponseBody = Full; + + fn response_for_panic( + &mut self, + err: Box, + ) -> Response { + if let Some(s) = err.downcast_ref::() { + tracing::error!("Service panicked: {}", s); + } else if let Some(s) = err.downcast_ref::<&str>() { + tracing::error!("Service panicked: {}", s); + } else { + tracing::error!( + "Service panicked but `CatchPanic` was unable to downcast the panic info" + ); + }; + + let mut res = Response::new(Full::from("Service panicked")); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + + #[allow(clippy::declare_interior_mutable_const)] + const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); + res.headers_mut() + .insert(http::header::CONTENT_TYPE, TEXT_PLAIN); + + res + } +} + +#[cfg(test)] +mod tests { + #![allow(unreachable_code)] + + use super::*; + use hyper::{Body, Response}; + use std::convert::Infallible; + use tower::{ServiceBuilder, ServiceExt}; + + #[tokio::test] + async fn panic_before_returning_future() { + let svc = ServiceBuilder::new() + .layer(CatchPanicLayer::new()) + .service_fn(|_: Request| { + panic!("service panic"); + async { Ok::<_, Infallible>(Response::new(Body::empty())) } + }); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } + + #[tokio::test] + async fn panic_in_future() { + let svc = ServiceBuilder::new() + .layer(CatchPanicLayer::new()) + .service_fn(|_: Request| async { + panic!("future panic"); + Ok::<_, Infallible>(Response::new(Body::empty())) + }); + + let req = Request::new(Body::empty()); + + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"Service panicked"); + } +} diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index c8f07ec6..987a3b3d 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -322,6 +322,10 @@ pub mod cors; #[cfg_attr(docsrs, doc(cfg(feature = "request-id")))] pub mod request_id; +#[cfg(feature = "catch-panic")] +#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] +pub mod catch_panic; + pub mod classify; pub mod services;