diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index a5be8fd1..a178fa3b 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -13,8 +13,8 @@ use serde::{Deserialize, Serialize}; use std::{convert::TryFrom, env, fmt, future::Future, panic}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; -use tower::util::ServiceFn; pub use tower::{self, service_fn, Service}; +use tower::{util::ServiceFn, ServiceExt}; use tracing::{error, trace}; mod requests; @@ -112,41 +112,56 @@ where env::set_var("_X_AMZN_TRACE_ID", xray_trace_id); let request_id = &ctx.request_id.clone(); - let task = panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(LambdaEvent::new(body, ctx)))); - - let req = match task { - Ok(response) => match response.await { - Ok(response) => { - trace!("Ok response from handler (run loop)"); - EventCompletionRequest { - request_id, - body: response, - } - .into_req() - } - Err(err) => { - error!("{:?}", err); // logs the error in CloudWatch - EventErrorRequest { - request_id, - diagnostic: Diagnostic { - error_type: type_name_of_val(&err).to_owned(), - error_message: format!("{}", err), // returns the error to the caller via Lambda API - }, + let req = match handler.ready().await { + Ok(handler) => { + let task = + panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(LambdaEvent::new(body, ctx)))); + match task { + Ok(response) => match response.await { + Ok(response) => { + trace!("Ok response from handler (run loop)"); + EventCompletionRequest { + request_id, + body: response, + } + .into_req() + } + Err(err) => { + error!("{:?}", err); // logs the error in CloudWatch + EventErrorRequest { + request_id, + diagnostic: Diagnostic { + error_type: type_name_of_val(&err).to_owned(), + error_message: format!("{}", err), // returns the error to the caller via Lambda API + }, + } + .into_req() + } + }, + Err(err) => { + error!("{:?}", err); + EventErrorRequest { + request_id, + diagnostic: Diagnostic { + error_type: type_name_of_val(&err).to_owned(), + error_message: if let Some(msg) = err.downcast_ref::<&str>() { + format!("Lambda panicked: {}", msg) + } else { + "Lambda panicked".to_string() + }, + }, + } + .into_req() } - .into_req() } - }, + } Err(err) => { - error!("{:?}", err); + error!("{:?}", err); // logs the error in CloudWatch EventErrorRequest { request_id, diagnostic: Diagnostic { error_type: type_name_of_val(&err).to_owned(), - error_message: if let Some(msg) = err.downcast_ref::<&str>() { - format!("Lambda panicked: {}", msg) - } else { - "Lambda panicked".to_string() - }, + error_message: format!("{}", err), // returns the error to the caller via Lambda API }, } .into_req()