Skip to content

Commit

Permalink
Feature: Add a layer that catches panics (#366)
Browse files Browse the repository at this point in the history
* Feature: Add a layer that catches panics

This allows preventing job execution from killing workers and returns an error containing the backtrace

* fix: backtrace as it may be different

* add: example for catch-panic

* fix: make  not default
  • Loading branch information
geofmureithi authored Jul 13, 2024
1 parent d70e479 commit 97ff348
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ timeout = ["tower/timeout"]
limit = ["tower/limit"]
## Support filtering jobs based on a predicate
filter = ["tower/filter"]
## Captures panics in executions and convert them to errors
catch-panic = ["dep:backtrace"]
## Compatibility with async-std and smol runtimes
async-std-comp = ["async-std"]
## Compatibility with tokio and actix runtimes
Expand All @@ -46,6 +48,7 @@ layers = [
"timeout",
"limit",
"filter",
"catch-panic",
]

docsrs = ["document-features"]
Expand Down Expand Up @@ -134,6 +137,7 @@ pin-project-lite = "0.2.14"
uuid = { version = "1.8", optional = true }
ulid = { version = "1", optional = true }
serde = { version = "1.0", features = ["derive"] }
backtrace = { version = "0.3", optional = true }

[dependencies.tracing]
default-features = false
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async fn produce_route_jobs(storage: &RedisStorage<Email>) -> Result<()> {
- _timeout_ — Support timeouts on jobs
- _limit_ — 💪 Limit the amount of jobs
- _filter_ — Support filtering jobs based on a predicate
- _catch-panic_ - Catch panics that occur during execution

## Storage Comparison

Expand Down
2 changes: 1 addition & 1 deletion examples/basics/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
thiserror = "1"
tokio = { version = "1", features = ["full"] }
apalis = { path = "../../", features = ["limit", "tokio-comp"] }
apalis = { path = "../../", features = ["limit", "tokio-comp", "catch-panic"] }
apalis-sql = { path = "../../packages/apalis-sql" }
serde = "1"
tracing-subscriber = "0.3.11"
Expand Down
7 changes: 6 additions & 1 deletion examples/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod service;

use std::time::Duration;

use apalis::{layers::tracing::TraceLayer, prelude::*};
use apalis::{
layers::{catch_panic::CatchPanicLayer, tracing::TraceLayer},
prelude::*,
};
use apalis_sql::sqlite::{SqlitePool, SqliteStorage};

use email_service::Email;
Expand Down Expand Up @@ -96,6 +99,8 @@ async fn main() -> Result<(), std::io::Error> {
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
WorkerBuilder::new("tasty-banana")
// This handles any panics that may occur in any of the layers below
.layer(CatchPanicLayer::new())
.layer(TraceLayer::new())
.layer(LogLayer::new("some-log-example"))
// Add shared context to all jobs executed by this worker
Expand Down
181 changes: 181 additions & 0 deletions src/layers/catch_panic/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use std::fmt;
use std::future::Future;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};

use apalis_core::error::Error;
use apalis_core::request::Request;
use backtrace::Backtrace;
use tower::Layer;
use tower::Service;

/// Apalis Layer that catches panics in the service.
#[derive(Clone, Debug)]
pub struct CatchPanicLayer;

impl CatchPanicLayer {
/// Creates a new `CatchPanicLayer`.
pub fn new() -> Self {
CatchPanicLayer
}
}

impl Default for CatchPanicLayer {
fn default() -> Self {
Self::new()
}
}

impl<S> Layer<S> for CatchPanicLayer {
type Service = CatchPanicService<S>;

fn layer(&self, service: S) -> Self::Service {
CatchPanicService { service }
}
}

/// Apalis Service that catches panics.
#[derive(Clone, Debug)]
pub struct CatchPanicService<S> {
service: S,
}

impl<S, J, Res> Service<Request<J>> for CatchPanicService<S>
where
S: Service<Request<J>, Response = Res, Error = Error>,
{
type Response = S::Response;
type Error = S::Error;
type Future = CatchPanicFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}

fn call(&mut self, request: Request<J>) -> Self::Future {
CatchPanicFuture {
future: self.service.call(request),
}
}
}

pin_project_lite::pin_project! {
/// A wrapper that catches panics during execution
pub struct CatchPanicFuture<F> {
#[pin]
future: F,

}
}

/// An error generated from a panic
#[derive(Debug, Clone)]
pub struct PanicError(pub String, pub Backtrace);

impl std::error::Error for PanicError {}

impl fmt::Display for PanicError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "PanicError: {}, Backtrace: {:?}", self.0, self.1)
}
}

impl<F, Res> Future for CatchPanicFuture<F>
where
F: Future<Output = Result<Res, Error>>,
{
type Output = Result<Res, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

match catch_unwind(AssertUnwindSafe(|| this.future.poll(cx))) {
Ok(res) => res,
Err(e) => {
let panic_info = if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
Poll::Ready(Err(Error::Failed(Box::new(PanicError(
panic_info,
Backtrace::new(),
)))))
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::task::{Context, Poll};
use tower::Service;

#[derive(Clone, Debug)]
struct TestJob;

#[derive(Clone)]
struct TestService;

impl Service<Request<TestJob>> for TestService {
type Response = usize;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Request<TestJob>) -> Self::Future {
Box::pin(async { Ok(42) })
}
}

#[tokio::test]
async fn test_catch_panic_layer() {
let layer = CatchPanicLayer::new();
let mut service = layer.layer(TestService);

let request = Request::new(TestJob);
let response = service.call(request).await;

assert!(response.is_ok());
}

#[tokio::test]
async fn test_catch_panic_layer_panics() {
struct PanicService;

impl Service<Request<TestJob>> for PanicService {
type Response = usize;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Request<TestJob>) -> Self::Future {
Box::pin(async { None.unwrap() })
}
}

let layer = CatchPanicLayer::new();
let mut service = layer.layer(PanicService);

let request = Request::new(TestJob);
let response = service.call(request).await;

assert!(response.is_err());

assert_eq!(
response.unwrap_err().to_string()[0..87],
*"Task Failed: PanicError: called `Option::unwrap()` on a `None` value, Backtrace: 0: "
);
}
}
5 changes: 5 additions & 0 deletions src/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ pub mod limit {
#[cfg(feature = "timeout")]
#[cfg_attr(docsrs, doc(cfg(feature = "timeout")))]
pub use tower::timeout::TimeoutLayer;

/// catch panic middleware for apalis
#[cfg(feature = "catch-panic")]
#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))]
pub mod catch_panic;

0 comments on commit 97ff348

Please sign in to comment.