From 97ff3489c38c1996d92610cd39cfd0d7c501d2c7 Mon Sep 17 00:00:00 2001 From: Geoffrey Mureithi <95377562+geofmureithi@users.noreply.github.com> Date: Sat, 13 Jul 2024 08:39:56 +0300 Subject: [PATCH] Feature: Add a layer that catches panics (#366) * Feature: Add a layer that catches panics This allows preventing job execution from killing workers and returns an error containing the backtrace * fix: backtrace as it may be different * add: example for catch-panic * fix: make not default --- Cargo.toml | 4 + README.md | 1 + examples/basics/Cargo.toml | 2 +- examples/basics/src/main.rs | 7 +- src/layers/catch_panic/mod.rs | 181 ++++++++++++++++++++++++++++++++++ src/layers/mod.rs | 5 + 6 files changed, 198 insertions(+), 2 deletions(-) create mode 100644 src/layers/catch_panic/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 4c3dc410..4e719d5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,8 @@ timeout = ["tower/timeout"] limit = ["tower/limit"] ## Support filtering jobs based on a predicate filter = ["tower/filter"] +## Captures panics in executions and convert them to errors +catch-panic = ["dep:backtrace"] ## Compatibility with async-std and smol runtimes async-std-comp = ["async-std"] ## Compatibility with tokio and actix runtimes @@ -46,6 +48,7 @@ layers = [ "timeout", "limit", "filter", + "catch-panic", ] docsrs = ["document-features"] @@ -134,6 +137,7 @@ pin-project-lite = "0.2.14" uuid = { version = "1.8", optional = true } ulid = { version = "1", optional = true } serde = { version = "1.0", features = ["derive"] } +backtrace = { version = "0.3", optional = true } [dependencies.tracing] default-features = false diff --git a/README.md b/README.md index 663549cc..0b2234d7 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ async fn produce_route_jobs(storage: &RedisStorage) -> Result<()> { - _timeout_ — Support timeouts on jobs - _limit_ — 💪 Limit the amount of jobs - _filter_ — Support filtering jobs based on a predicate +- _catch-panic_ - Catch panics that occur during execution ## Storage Comparison diff --git a/examples/basics/Cargo.toml b/examples/basics/Cargo.toml index d77bccbe..feade0b9 100644 --- a/examples/basics/Cargo.toml +++ b/examples/basics/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" [dependencies] thiserror = "1" tokio = { version = "1", features = ["full"] } -apalis = { path = "../../", features = ["limit", "tokio-comp"] } +apalis = { path = "../../", features = ["limit", "tokio-comp", "catch-panic"] } apalis-sql = { path = "../../packages/apalis-sql" } serde = "1" tracing-subscriber = "0.3.11" diff --git a/examples/basics/src/main.rs b/examples/basics/src/main.rs index 67a78a33..e6b1f0cb 100644 --- a/examples/basics/src/main.rs +++ b/examples/basics/src/main.rs @@ -4,7 +4,10 @@ mod service; use std::time::Duration; -use apalis::{layers::tracing::TraceLayer, prelude::*}; +use apalis::{ + layers::{catch_panic::CatchPanicLayer, tracing::TraceLayer}, + prelude::*, +}; use apalis_sql::sqlite::{SqlitePool, SqliteStorage}; use email_service::Email; @@ -96,6 +99,8 @@ async fn main() -> Result<(), std::io::Error> { Monitor::::new() .register_with_count(2, { WorkerBuilder::new("tasty-banana") + // This handles any panics that may occur in any of the layers below + .layer(CatchPanicLayer::new()) .layer(TraceLayer::new()) .layer(LogLayer::new("some-log-example")) // Add shared context to all jobs executed by this worker diff --git a/src/layers/catch_panic/mod.rs b/src/layers/catch_panic/mod.rs new file mode 100644 index 00000000..f7b149bf --- /dev/null +++ b/src/layers/catch_panic/mod.rs @@ -0,0 +1,181 @@ +use std::fmt; +use std::future::Future; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use apalis_core::error::Error; +use apalis_core::request::Request; +use backtrace::Backtrace; +use tower::Layer; +use tower::Service; + +/// Apalis Layer that catches panics in the service. +#[derive(Clone, Debug)] +pub struct CatchPanicLayer; + +impl CatchPanicLayer { + /// Creates a new `CatchPanicLayer`. + pub fn new() -> Self { + CatchPanicLayer + } +} + +impl Default for CatchPanicLayer { + fn default() -> Self { + Self::new() + } +} + +impl Layer for CatchPanicLayer { + type Service = CatchPanicService; + + fn layer(&self, service: S) -> Self::Service { + CatchPanicService { service } + } +} + +/// Apalis Service that catches panics. +#[derive(Clone, Debug)] +pub struct CatchPanicService { + service: S, +} + +impl Service> for CatchPanicService +where + S: Service, Response = Res, Error = Error>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = CatchPanicFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + CatchPanicFuture { + future: self.service.call(request), + } + } +} + +pin_project_lite::pin_project! { + /// A wrapper that catches panics during execution + pub struct CatchPanicFuture { + #[pin] + future: F, + + } +} + +/// An error generated from a panic +#[derive(Debug, Clone)] +pub struct PanicError(pub String, pub Backtrace); + +impl std::error::Error for PanicError {} + +impl fmt::Display for PanicError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PanicError: {}, Backtrace: {:?}", self.0, self.1) + } +} + +impl Future for CatchPanicFuture +where + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match catch_unwind(AssertUnwindSafe(|| this.future.poll(cx))) { + Ok(res) => res, + Err(e) => { + let panic_info = if let Some(s) = e.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = e.downcast_ref::() { + s.clone() + } else { + "Unknown panic".to_string() + }; + Poll::Ready(Err(Error::Failed(Box::new(PanicError( + panic_info, + Backtrace::new(), + ))))) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::task::{Context, Poll}; + use tower::Service; + + #[derive(Clone, Debug)] + struct TestJob; + + #[derive(Clone)] + struct TestService; + + impl Service> for TestService { + type Response = usize; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + Box::pin(async { Ok(42) }) + } + } + + #[tokio::test] + async fn test_catch_panic_layer() { + let layer = CatchPanicLayer::new(); + let mut service = layer.layer(TestService); + + let request = Request::new(TestJob); + let response = service.call(request).await; + + assert!(response.is_ok()); + } + + #[tokio::test] + async fn test_catch_panic_layer_panics() { + struct PanicService; + + impl Service> for PanicService { + type Response = usize; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + Box::pin(async { None.unwrap() }) + } + } + + let layer = CatchPanicLayer::new(); + let mut service = layer.layer(PanicService); + + let request = Request::new(TestJob); + let response = service.call(request).await; + + assert!(response.is_err()); + + assert_eq!( + response.unwrap_err().to_string()[0..87], + *"Task Failed: PanicError: called `Option::unwrap()` on a `None` value, Backtrace: 0: " + ); + } +} diff --git a/src/layers/mod.rs b/src/layers/mod.rs index e7b5e99e..f990573a 100644 --- a/src/layers/mod.rs +++ b/src/layers/mod.rs @@ -25,3 +25,8 @@ pub mod limit { #[cfg(feature = "timeout")] #[cfg_attr(docsrs, doc(cfg(feature = "timeout")))] pub use tower::timeout::TimeoutLayer; + +/// catch panic middleware for apalis +#[cfg(feature = "catch-panic")] +#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))] +pub mod catch_panic;