diff --git a/.github/workflows/build-events.yml b/.github/workflows/build-events.yml
index d9f5c72a..3026f1ac 100644
--- a/.github/workflows/build-events.yml
+++ b/.github/workflows/build-events.yml
@@ -3,10 +3,10 @@ name: Check Lambda Events
 on:
   push:
     paths:
-      - 'lambda-events/**'
+      - "lambda-events/**"
   pull_request:
     paths:
-      - 'lambda-events/**'
+      - "lambda-events/**"
 
 jobs:
   build:
@@ -14,7 +14,7 @@ jobs:
     strategy:
       matrix:
         toolchain:
-          - "1.66.0" # Current MSRV
+          - "1.70.0" # Current MSRV
           - stable
     env:
       RUST_BACKTRACE: 1
diff --git a/.github/workflows/build-extension.yml b/.github/workflows/build-extension.yml
index d9bcc989..d09b08c0 100644
--- a/.github/workflows/build-extension.yml
+++ b/.github/workflows/build-extension.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       matrix:
         toolchain:
-          - "1.66.0" # Current MSRV
+          - "1.70.0" # Current MSRV
           - stable
     env:
       RUST_BACKTRACE: 1
diff --git a/.github/workflows/build-runtime.yml b/.github/workflows/build-runtime.yml
index 25cd83ec..a52927b5 100644
--- a/.github/workflows/build-runtime.yml
+++ b/.github/workflows/build-runtime.yml
@@ -19,7 +19,7 @@ jobs:
     strategy:
      matrix:
         toolchain:
-          - "1.66.0" # Current MSRV
+          - "1.70.0" # Current MSRV
           - stable
     env:
       RUST_BACKTRACE: 1
diff --git a/Cargo.toml b/Cargo.toml
index cba3ba3b..09b046e8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,7 +6,7 @@ members = [
     "lambda-runtime-api-client",
     "lambda-runtime",
     "lambda-extension",
-    "lambda-events"
+    "lambda-events",
 ]
 
 exclude = ["examples"]
@@ -26,4 +26,5 @@ hyper = "1.0"
 hyper-util = "0.1.1"
 pin-project-lite = "0.2"
 tower = "0.4"
+tower-layer = "0.3"
 tower-service = "0.3"
diff --git a/README.md b/README.md
index 8f4838c5..331635d2 100644
--- a/README.md
+++ b/README.md
@@ -458,7 +458,7 @@ This will make your function compile much faster.
 
 ## Supported Rust Versions (MSRV)
 
-The AWS Lambda Rust Runtime requires a minimum of Rust 1.66, and is not guaranteed to build on compiler versions earlier than that.
+The AWS Lambda Rust Runtime requires a minimum of Rust 1.70, and is not guaranteed to build on compiler versions earlier than that.
 
 ## Security
 
diff --git a/examples/opentelemetry-tracing/Cargo.toml b/examples/opentelemetry-tracing/Cargo.toml
new file mode 100644
index 00000000..27a778b5
--- /dev/null
+++ b/examples/opentelemetry-tracing/Cargo.toml
@@ -0,0 +1,36 @@
+[package]
+name = "opentelemetry-tracing"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+# Library dependencies
+lambda_runtime = { path = "../../lambda-runtime" }
+pin-project = "1"
+opentelemetry-semantic-conventions = "0.14"
+tower = "0.4"
+tracing = "0.1"
+
+# Binary dependencies
+opentelemetry = { version = "0.22", optional = true }
+opentelemetry_sdk = { version = "0.22", features = ["rt-tokio"], optional = true }
+opentelemetry-stdout = { version = "0.3", features = ["trace"], optional = true }
+serde_json = { version = "1.0", optional = true }
+tokio = { version = "1", optional = true }
+tracing-opentelemetry = { version = "0.23", optional = true }
+tracing-subscriber = { version = "0.3", optional = true }
+
+[features]
+build-binary = [
+    "opentelemetry",
+    "opentelemetry_sdk",
+    "opentelemetry-stdout",
+    "serde_json",
+    "tokio",
+    "tracing-opentelemetry",
+    "tracing-subscriber",
+]
+
+[[bin]]
+name = "opentelemetry-tracing"
+required-features = ["build-binary"]
diff --git a/examples/opentelemetry-tracing/src/lib.rs b/examples/opentelemetry-tracing/src/lib.rs
new file mode 100644
index 00000000..82f12a16
--- /dev/null
+++ b/examples/opentelemetry-tracing/src/lib.rs
@@ -0,0 +1,113 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::task;
+
+use lambda_runtime::LambdaInvocation;
+use opentelemetry_semantic_conventions::trace as traceconv;
+use pin_project::pin_project;
+use tower::{Layer, Service};
+use tracing::instrument::Instrumented;
+use tracing::Instrument;
+
+/// Tower layer to add OpenTelemetry tracing to a Lambda function invocation. The layer accepts
+/// a function to flush OpenTelemetry after the end of the invocation.
+pub struct OpenTelemetryLayer<F> {
+    flush_fn: F,
+}
+
+impl<F> OpenTelemetryLayer<F>
+where
+    F: Fn() + Clone,
+{
+    pub fn new(flush_fn: F) -> Self {
+        Self { flush_fn }
+    }
+}
+
+impl<S, F> Layer<S> for OpenTelemetryLayer<F>
+where
+    F: Fn() + Clone,
+{
+    type Service = OpenTelemetryService<S, F>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        OpenTelemetryService {
+            inner,
+            flush_fn: self.flush_fn.clone(),
+            coldstart: true,
+        }
+    }
+}
+
+/// Tower service created by [OpenTelemetryLayer].
+pub struct OpenTelemetryService<S, F> {
+    inner: S,
+    flush_fn: F,
+    coldstart: bool,
+}
+
+impl<S, F> Service<LambdaInvocation> for OpenTelemetryService<S, F>
+where
+    S: Service<LambdaInvocation, Response = ()>,
+    F: Fn() + Clone,
+{
+    type Error = S::Error;
+    type Response = ();
+    type Future = OpenTelemetryFuture<Instrumented<S::Future>, F>;
+
+    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
+        let span = tracing::info_span!(
+            "Lambda function invocation",
+            "otel.name" = req.context.env_config.function_name,
+            { traceconv::FAAS_TRIGGER } = "http",
+            { traceconv::FAAS_INVOCATION_ID } = req.context.request_id,
+            { traceconv::FAAS_COLDSTART } = self.coldstart
+        );
+
+        // After the first execution, we can set 'coldstart' to false
+        self.coldstart = false;
+
+        let fut = self.inner.call(req).instrument(span);
+        OpenTelemetryFuture {
+            future: Some(fut),
+            flush_fn: self.flush_fn.clone(),
+        }
+    }
+}
+
+/// Future created by [OpenTelemetryService].
+#[pin_project]
+pub struct OpenTelemetryFuture<Fut, F> {
+    #[pin]
+    future: Option<Fut>,
+    flush_fn: F,
+}
+
+impl<Fut, F> Future for OpenTelemetryFuture<Fut, F>
+where
+    Fut: Future,
+    F: Fn(),
+{
+    type Output = Fut::Output;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
+        // First, try to get the ready value of the future
+        let ready = task::ready!(self
+            .as_mut()
+            .project()
+            .future
+            .as_pin_mut()
+            .expect("future polled after completion")
+            .poll(cx));
+
+        // If we got the ready value, we first drop the future: this ensures that the
+        // OpenTelemetry span attached to it is closed and included in the subsequent flush.
+        Pin::set(&mut self.as_mut().project().future, None);
+        (self.project().flush_fn)();
+        task::Poll::Ready(ready)
+    }
+}
diff --git a/examples/opentelemetry-tracing/src/main.rs b/examples/opentelemetry-tracing/src/main.rs
new file mode 100644
index 00000000..68038366
--- /dev/null
+++ b/examples/opentelemetry-tracing/src/main.rs
@@ -0,0 +1,34 @@
+use lambda_runtime::{LambdaEvent, Runtime};
+use opentelemetry::trace::TracerProvider;
+use opentelemetry_sdk::{runtime, trace};
+use opentelemetry_tracing::OpenTelemetryLayer;
+use tower::{service_fn, BoxError};
+use tracing_subscriber::prelude::*;
+
+async fn echo(event: LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, BoxError> {
+    Ok(event.payload)
+}
+
+#[tokio::main]
+async fn main() -> Result<(), BoxError> {
+    // Set up OpenTelemetry tracer provider that writes spans to stdout for debugging purposes
+    let exporter = opentelemetry_stdout::SpanExporter::default();
+    let tracer_provider = trace::TracerProvider::builder()
+        .with_batch_exporter(exporter, runtime::Tokio)
+        .build();
+
+    // Set up link between OpenTelemetry and tracing crate
+    tracing_subscriber::registry()
+        .with(tracing_opentelemetry::OpenTelemetryLayer::new(
+            tracer_provider.tracer("my-app"),
+        ))
+        .init();
+
+    // Initialize the Lambda runtime and add OpenTelemetry tracing
+    let runtime = Runtime::new(service_fn(echo)).layer(OpenTelemetryLayer::new(|| {
+        // Make sure that the trace is exported before the Lambda runtime is frozen
+        tracer_provider.force_flush();
+    }));
+    runtime.run().await?;
+    Ok(())
+}
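As a usage note beyond the diff: the example's layer composes with the runtime's built-in middleware through chained `.layer()` calls. A minimal sketch of that composition (illustrative only, not part of this change; `Runtime::layer` semantics are described in runtime.rs further down):

```rust
use lambda_runtime::{layers, Error, LambdaEvent, Runtime};
use opentelemetry_tracing::OpenTelemetryLayer;
use serde_json::Value;
use tower::service_fn;

async fn echo(event: LambdaEvent<Value>) -> Result<Value, Error> {
    Ok(event.payload)
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Layers added later run earlier in the request path (see Runtime::layer
    // below), so TracingLayer's span here wraps the OpenTelemetry span.
    Runtime::new(service_fn(echo))
        .layer(OpenTelemetryLayer::new(|| { /* flush the tracer provider here */ }))
        .layer(layers::TracingLayer::new())
        .run()
        .await
}
```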
diff --git a/lambda-events/src/custom_serde/mod.rs b/lambda-events/src/custom_serde/mod.rs
index 46d121d1..030cb5b3 100644
--- a/lambda-events/src/custom_serde/mod.rs
+++ b/lambda-events/src/custom_serde/mod.rs
@@ -177,18 +177,18 @@ mod test {
 
         let test = r#"{"v": null}"#;
         let decoded: Test = serde_json::from_str(test).unwrap();
-        assert_eq!(false, decoded.v);
+        assert!(!decoded.v);
 
         let test = r#"{}"#;
         let decoded: Test = serde_json::from_str(test).unwrap();
-        assert_eq!(false, decoded.v);
+        assert!(!decoded.v);
 
         let test = r#"{"v": true}"#;
         let decoded: Test = serde_json::from_str(test).unwrap();
-        assert_eq!(true, decoded.v);
+        assert!(decoded.v);
 
         let test = r#"{"v": false}"#;
         let decoded: Test = serde_json::from_str(test).unwrap();
-        assert_eq!(false, decoded.v);
+        assert!(!decoded.v);
     }
 }
diff --git a/lambda-events/src/event/dynamodb/attributes.rs b/lambda-events/src/event/dynamodb/attributes.rs
index aad2cd4b..e1a42c83 100644
--- a/lambda-events/src/event/dynamodb/attributes.rs
+++ b/lambda-events/src/event/dynamodb/attributes.rs
@@ -83,7 +83,7 @@ mod test {
         let attr: AttributeValue = serde_json::from_value(value.clone()).unwrap();
 
         match attr {
-            AttributeValue::Bool(b) => assert_eq!(true, b),
+            AttributeValue::Bool(b) => assert!(b),
             other => panic!("unexpected value {:?}", other),
         }
 
diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs
index 8e52d416..4315dfd7 100644
--- a/lambda-runtime-api-client/src/lib.rs
+++ b/lambda-runtime-api-client/src/lib.rs
@@ -4,10 +4,11 @@
 //! This crate includes a base HTTP client to interact with
 //! the AWS Lambda Runtime API.
 
+use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
 use http::{uri::PathAndQuery, uri::Scheme, Request, Response, Uri};
 use hyper::body::Incoming;
 use hyper_util::client::legacy::connect::HttpConnector;
-use std::{convert::TryInto, fmt::Debug};
+use std::{convert::TryInto, fmt::Debug, future};
 
 const USER_AGENT_HEADER: &str = "User-Agent";
 const DEFAULT_USER_AGENT: &str = concat!("aws-lambda-rust/", env!("CARGO_PKG_VERSION"));
@@ -42,9 +43,15 @@ impl Client {
 impl Client {
     /// Send a given request to the Runtime API.
     /// Use the client's base URI to ensure the API endpoint is correct.
-    pub async fn call(&self, req: Request<Body>) -> Result<Response<Incoming>, BoxError> {
-        let req = self.set_origin(req)?;
-        self.client.request(req).await.map_err(Into::into)
+    pub fn call(&self, req: Request<Body>) -> BoxFuture<'static, Result<Response<Incoming>, BoxError>> {
+        // NOTE: This method returns a boxed future such that the future has a static lifetime.
+        //       Due to limitations around the Rust async implementation as of Mar 2024, this is
+        //       required to minimize constraints on the handler passed to [lambda_runtime::run].
+        let req = match self.set_origin(req) {
+            Ok(req) => req,
+            Err(err) => return future::ready(Err(err)).boxed(),
+        };
+        self.client.request(req).map_err(Into::into).boxed()
     }
 
     /// Create a new client with a given base URI and HTTP connector.
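One consequence of the new `call` signature worth spelling out: the returned future no longer borrows the client, so callers can store it inside their own future types. A small sketch under that assumption (hypothetical caller code, not part of the diff):

```rust
use futures_util::future::BoxFuture;
use hyper::body::Incoming;
use lambda_runtime_api_client::{body::Body, BoxError, Client};

// Hypothetical caller: because `call` now returns BoxFuture<'static, _>, the
// future can be stored in a field without borrowing the client that made it.
struct InFlight {
    response: BoxFuture<'static, Result<http::Response<Incoming>, BoxError>>,
}

fn start(client: &Client, req: http::Request<Body>) -> InFlight {
    InFlight {
        response: client.call(req),
    }
}
```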
diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml
index d61e5594..d9eca35a 100644
--- a/lambda-runtime/Cargo.toml
+++ b/lambda-runtime/Cargo.toml
@@ -37,6 +37,7 @@ hyper-util = { workspace = true, features = [
     "tokio",
 ] }
 lambda_runtime_api_client = { version = "0.10", path = "../lambda-runtime-api-client" }
+pin-project = "1"
 serde = { version = "1", features = ["derive", "rc"] }
 serde_json = "^1"
 serde_path_to_error = "0.1.11"
@@ -48,6 +49,7 @@ tokio = { version = "1.0", features = [
 ] }
 tokio-stream = "0.1.2"
 tower = { workspace = true, features = ["util"] }
+tower-layer = { workspace = true }
 tracing = { version = "0.1", features = ["log"] }
 
 [dev-dependencies]
diff --git a/lambda-runtime/src/layers/api_client.rs b/lambda-runtime/src/layers/api_client.rs
new file mode 100644
index 00000000..b6d9acf8
--- /dev/null
+++ b/lambda-runtime/src/layers/api_client.rs
@@ -0,0 +1,85 @@
+use crate::LambdaInvocation;
+use futures::{future::BoxFuture, ready, FutureExt, TryFutureExt};
+use hyper::body::Incoming;
+use lambda_runtime_api_client::{body::Body, BoxError, Client};
+use pin_project::pin_project;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task;
+use tower::Service;
+use tracing::error;
+
+/// Tower service that sends a Lambda Runtime API response to the Lambda Runtime HTTP API using
+/// a previously initialized client.
+///
+/// This type is only meant for internal use in the Lambda runtime crate. It neither augments the
+/// inner service's request type nor its error type. However, this service returns an empty
+/// response `()` as the Lambda request has been completed.
+pub struct RuntimeApiClientService<S> {
+    inner: S,
+    client: Arc<Client>,
+}
+
+impl<S> RuntimeApiClientService<S> {
+    pub fn new(inner: S, client: Arc<Client>) -> Self {
+        Self { inner, client }
+    }
+}
+
+impl<S> Service<LambdaInvocation> for RuntimeApiClientService<S>
+where
+    S: Service<LambdaInvocation>,
+    S::Future: Future<Output = Result<http::Request<Body>, BoxError>>,
+{
+    type Response = ();
+    type Error = S::Error;
+    type Future = RuntimeApiClientFuture<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
+        let request_fut = self.inner.call(req);
+        let client = self.client.clone();
+        RuntimeApiClientFuture::First(request_fut, client)
+    }
+}
+
+#[pin_project(project = RuntimeApiClientFutureProj)]
+pub enum RuntimeApiClientFuture<F> {
+    First(#[pin] F, Arc<Client>),
+    Second(#[pin] BoxFuture<'static, Result<http::Response<Incoming>, BoxError>>),
+}
+
+impl<F> Future for RuntimeApiClientFuture<F>
+where
+    F: Future<Output = Result<http::Request<Body>, BoxError>>,
+{
+    type Output = Result<(), BoxError>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
+        // NOTE: We loop here to directly poll the second future once the first has finished.
+        task::Poll::Ready(loop {
+            match self.as_mut().project() {
+                RuntimeApiClientFutureProj::First(fut, client) => match ready!(fut.poll(cx)) {
+                    Ok(ok) => {
+                        // NOTE: We use 'client.call' here to obtain a future with a static
+                        //       lifetime. Otherwise, this future would need to be self-referential...
+                        let next_fut = client
+                            .call(ok)
+                            .map_err(|err| {
+                                error!(error = ?err, "failed to send request to Lambda Runtime API");
+                                err
+                            })
+                            .boxed();
+                        self.set(RuntimeApiClientFuture::Second(next_fut));
+                    }
+                    Err(err) => break Err(err),
+                },
+                RuntimeApiClientFutureProj::Second(fut) => break ready!(fut.poll(cx)).map(|_| ()),
+            }
+        })
+    }
+}
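The `RuntimeApiClientFuture` enum above is a hand-rolled two-state future. A stripped-down sketch of the same pattern with the Lambda-specific types removed (illustrative only; the `A::Output: Into<B>` bound is an assumption made to keep the sketch self-contained):

```rust
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

// Two sequential futures without an `async fn`: poll `First` to completion,
// then replace the pinned state with `Second` and keep polling.
#[pin_project(project = ChainProj)]
enum Chain<A, B> {
    First(#[pin] A),
    Second(#[pin] B),
}

impl<A, B> Future for Chain<A, B>
where
    A: Future,
    B: Future,
    A::Output: Into<B>,
{
    type Output = B::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match self.as_mut().project() {
                ChainProj::First(first) => {
                    // `ready!` returns Pending early; on Ready we swap states.
                    let out = ready!(first.poll(cx));
                    self.set(Chain::Second(out.into()));
                }
                ChainProj::Second(second) => return second.poll(cx),
            }
        }
    }
}
```

The conversion constructs the second future from the first's output, which mirrors how the runtime builds the Runtime API call out of the handler's response request.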
diff --git a/lambda-runtime/src/layers/api_response.rs b/lambda-runtime/src/layers/api_response.rs
new file mode 100644
index 00000000..266402cf
--- /dev/null
+++ b/lambda-runtime/src/layers/api_response.rs
@@ -0,0 +1,173 @@
+use crate::requests::{EventCompletionRequest, IntoRequest};
+use crate::runtime::LambdaInvocation;
+use crate::types::Diagnostic;
+use crate::{deserializer, IntoFunctionResponse};
+use crate::{EventErrorRequest, LambdaEvent};
+use futures::ready;
+use futures::Stream;
+use lambda_runtime_api_client::{body::Body, BoxError};
+use pin_project::pin_project;
+use serde::{Deserialize, Serialize};
+use std::fmt::Debug;
+use std::future::Future;
+use std::marker::PhantomData;
+use std::pin::Pin;
+use std::task;
+use tower::Service;
+use tracing::{error, trace};
+
+/// Tower service that turns the result or an error of a handler function into a Lambda Runtime API
+/// response.
+///
+/// This type is only meant for internal use in the Lambda runtime crate. The service augments both
+/// inputs and outputs: the input is converted from a [LambdaInvocation] into a [LambdaEvent]
+/// while any errors encountered during the conversion are turned into error responses. The service
+/// outputs either an HTTP request to send to the Lambda Runtime API or a boxed error which ought to
+/// be propagated to the caller to terminate the runtime.
+pub struct RuntimeApiResponseService<
+    S,
+    EventPayload,
+    Response,
+    BufferedResponse,
+    StreamingResponse,
+    StreamItem,
+    StreamError,
+> {
+    inner: S,
+    _phantom: PhantomData<(
+        EventPayload,
+        Response,
+        BufferedResponse,
+        StreamingResponse,
+        StreamItem,
+        StreamError,
+    )>,
+}
+
+impl<S, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
+    RuntimeApiResponseService<S, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
+{
+    pub fn new(inner: S) -> Self {
+        Self {
+            inner,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<'a, S, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
+    Service<LambdaInvocation>
+    for RuntimeApiResponseService<
+        S,
+        EventPayload,
+        Response,
+        BufferedResponse,
+        StreamingResponse,
+        StreamItem,
+        StreamError,
+    >
+where
+    S: Service<LambdaEvent<EventPayload>, Response = Response, Error = Diagnostic<'a>>,
+    EventPayload: for<'de> Deserialize<'de>,
+    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
+    BufferedResponse: Serialize,
+    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
+    StreamItem: Into<bytes::Bytes> + Send,
+    StreamError: Into<BoxError> + Send + Debug,
+{
+    type Response = http::Request<Body>;
+    type Error = BoxError;
+    type Future =
+        RuntimeApiResponseFuture<'a, S::Future, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>;
+
+    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
+        self.inner
+            .poll_ready(cx)
+            .map_err(|err| BoxError::from(format!("{}: {}", err.error_type, err.error_message)))
+    }
+
+    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
+        #[cfg(debug_assertions)]
+        if req.parts.status.is_server_error() {
+            error!("Lambda Runtime server returned an unexpected error");
+            return RuntimeApiResponseFuture::Ready(Some(Err(req.parts.status.to_string().into())));
+        }
+
+        // Utility closure to propagate a potential error from the conditionally executed trace
+        let trace_fn = || {
+            trace!(
+                body = std::str::from_utf8(&req.body)?,
+                "raw JSON event received from Lambda"
+            );
+            Ok(())
+        };
+        if let Err(err) = trace_fn() {
+            error!(error = ?err, "failed to parse raw JSON event received from Lambda");
+            return RuntimeApiResponseFuture::Ready(Some(Err(err)));
+        };
+
+        let request_id = req.context.request_id.clone();
+        let lambda_event = match deserializer::deserialize::<EventPayload>(&req.body, req.context) {
+            Ok(lambda_event) => lambda_event,
+            Err(err) => match build_event_error_request(&request_id, err) {
+                Ok(request) => return RuntimeApiResponseFuture::Ready(Some(Ok(request))),
+                Err(err) => {
+                    error!(error = ?err, "failed to build error response for Lambda Runtime API");
+                    return RuntimeApiResponseFuture::Ready(Some(Err(err)));
+                }
+            },
+        };
+
+        // Once the handler input has been generated successfully, pass it on to the inner service.
+        let fut = self.inner.call(lambda_event);
+        RuntimeApiResponseFuture::Future(fut, request_id, PhantomData)
+    }
+}
+
+fn build_event_error_request<'a, T>(request_id: &'a str, err: T) -> Result<http::Request<Body>, BoxError>
+where
+    T: Into<Diagnostic<'a>> + Debug,
+{
+    error!(error = ?err, "building error response for Lambda Runtime API");
+    EventErrorRequest::new(request_id, err).into_req()
+}
+
+#[pin_project(project = RuntimeApiResponseFutureProj)]
+pub enum RuntimeApiResponseFuture<'a, F, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError> {
+    Future(
+        #[pin] F,
+        String,
+        PhantomData<(
+            &'a (),
+            Response,
+            BufferedResponse,
+            StreamingResponse,
+            StreamItem,
+            StreamError,
+        )>,
+    ),
+    Ready(Option<Result<http::Request<Body>, BoxError>>),
+}
+
+impl<'a, F, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError> Future
+    for RuntimeApiResponseFuture<'a, F, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
+where
+    F: Future<Output = Result<Response, Diagnostic<'a>>>,
+    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
+    BufferedResponse: Serialize,
+    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
+    StreamItem: Into<bytes::Bytes> + Send,
+    StreamError: Into<BoxError> + Send + Debug,
+{
+    type Output = Result<http::Request<Body>, BoxError>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
+        task::Poll::Ready(match self.as_mut().project() {
+            RuntimeApiResponseFutureProj::Future(fut, request_id, _) => match ready!(fut.poll(cx)) {
+                Ok(ok) => EventCompletionRequest::new(request_id, ok).into_req(),
+                Err(err) => EventErrorRequest::new(request_id, err).into_req(),
+            },
+            RuntimeApiResponseFutureProj::Ready(ready) => ready.take().expect("future polled after completion"),
+        })
+    }
+}
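For context on the error path above: when deserialization fails, the service posts a serialized `Diagnostic` to the Runtime API's `/error` endpoint. A sketch of that payload (assuming `Diagnostic` serializes with the camelCase field names the Lambda Runtime API expects; `serde_json` is used here purely for illustration):

```rust
use lambda_runtime::Diagnostic;
use std::borrow::Cow;

fn main() {
    // Illustrative only: the diagnostic built when event deserialization fails.
    let diagnostic = Diagnostic {
        error_type: Cow::Borrowed("InvalidEventDataError"),
        error_message: Cow::Borrowed("Error parsing event data"),
    };
    // Posted to /2018-06-01/runtime/invocation/{id}/error; the Runtime API
    // expects a body of the form {"errorType": "...", "errorMessage": "..."}.
    println!("{}", serde_json::to_string(&diagnostic).unwrap());
}
```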
diff --git a/lambda-runtime/src/layers/mod.rs b/lambda-runtime/src/layers/mod.rs
new file mode 100644
index 00000000..27ce0d68
--- /dev/null
+++ b/lambda-runtime/src/layers/mod.rs
@@ -0,0 +1,12 @@
+// Internally used services.
+mod api_client;
+mod api_response;
+mod panic;
+
+// Publicly available services.
+mod trace;
+
+pub(crate) use api_client::RuntimeApiClientService;
+pub(crate) use api_response::RuntimeApiResponseService;
+pub(crate) use panic::CatchPanicService;
+pub use trace::TracingLayer;
diff --git a/lambda-runtime/src/layers/panic.rs b/lambda-runtime/src/layers/panic.rs
new file mode 100644
index 00000000..26ceeecc
--- /dev/null
+++ b/lambda-runtime/src/layers/panic.rs
@@ -0,0 +1,118 @@
+use crate::{Diagnostic, LambdaEvent};
+use futures::{future::CatchUnwind, FutureExt};
+use pin_project::pin_project;
+use std::any::Any;
+use std::borrow::Cow;
+use std::fmt::Debug;
+use std::future::Future;
+use std::marker::PhantomData;
+use std::panic::AssertUnwindSafe;
+use std::pin::Pin;
+use std::task;
+use tower::Service;
+use tracing::error;
+
+/// Tower service that transforms panics into an error. Panics are converted to errors both when
+/// constructed in [tower::Service::call] and when constructed in the returned
+/// [tower::Service::Future].
+///
+/// This type is only meant for internal use in the Lambda runtime crate. It neither augments the
+/// inner service's request type, nor its response type. It merely transforms the error type
+/// from `Into<Diagnostic<'a>> + Debug` into `Diagnostic<'a>` to turn panics into diagnostics.
+#[derive(Clone)]
+pub struct CatchPanicService<'a, S> {
+    inner: S,
+    _phantom: PhantomData<&'a ()>,
+}
+
+impl<'a, S> CatchPanicService<'a, S> {
+    pub fn new(inner: S) -> Self {
+        Self {
+            inner,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<'a, S, Payload> Service<LambdaEvent<Payload>> for CatchPanicService<'a, S>
+where
+    S: Service<LambdaEvent<Payload>>,
+    S::Future: 'a,
+    S::Error: Into<Diagnostic<'a>> + Debug,
+{
+    type Error = Diagnostic<'a>;
+    type Response = S::Response;
+    type Future = CatchPanicFuture<'a, S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx).map_err(|err| err.into())
+    }
+
+    fn call(&mut self, req: LambdaEvent<Payload>) -> Self::Future {
+        // Catch panics that result from calling `call` on the service
+        let task = std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.call(req)));
+
+        // Catch panics that result from polling the future returned from `call`
+        match task {
+            Ok(task) => {
+                let fut = AssertUnwindSafe(task).catch_unwind();
+                CatchPanicFuture::Future(fut, PhantomData)
+            }
+            Err(err) => {
+                error!(error = ?err, "user handler panicked");
+                CatchPanicFuture::Error(err)
+            }
+        }
+    }
+}
+
+/// Future returned by [CatchPanicService].
+#[pin_project(project = CatchPanicFutureProj)]
+pub enum CatchPanicFuture<'a, F> {
+    Future(#[pin] CatchUnwind<AssertUnwindSafe<F>>, PhantomData<&'a ()>),
+    Error(Box<dyn Any + Send>),
+}
+
+impl<'a, F, T, E> Future for CatchPanicFuture<'a, F>
+where
+    F: Future<Output = Result<T, E>>,
+    E: Into<Diagnostic<'a>> + Debug,
+{
+    type Output = Result<T, Diagnostic<'a>>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
+        use task::Poll;
+        match self.project() {
+            CatchPanicFutureProj::Future(fut, _) => match fut.poll(cx) {
+                Poll::Ready(ready) => match ready {
+                    Ok(inner_result) => Poll::Ready(inner_result.map_err(|err| err.into())),
+                    Err(err) => {
+                        error!(error = ?err, "user handler panicked");
+                        Poll::Ready(Err(Self::build_panic_diagnostic(&err)))
+                    }
+                },
+                Poll::Pending => Poll::Pending,
+            },
+            CatchPanicFutureProj::Error(err) => Poll::Ready(Err(Self::build_panic_diagnostic(err))),
+        }
+    }
+}
+
+impl<'a, F> CatchPanicFuture<'a, F> {
+    fn build_panic_diagnostic(err: &Box<dyn Any + Send>) -> Diagnostic<'a> {
+        let error_type = type_name_of_val(&err);
+        let msg = if let Some(msg) = err.downcast_ref::<&str>() {
+            format!("Lambda panicked: {msg}")
+        } else {
+            "Lambda panicked".to_string()
+        };
+        Diagnostic {
+            error_type: Cow::Borrowed(error_type),
+            error_message: Cow::Owned(msg),
+        }
+    }
+}
+
+fn type_name_of_val<T>(_: T) -> &'static str {
+    std::any::type_name::<T>()
+}
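A standalone illustration of the downcast in `build_panic_diagnostic`, using only the standard library (the panic message is arbitrary):

```rust
use std::panic;

fn main() {
    // A `panic!` with a string literal carries a `&str` payload; formatted
    // panics (`panic!("{x}")`) carry a `String` instead, hence the fallback.
    let err = panic::catch_unwind(|| panic!("This is intentionally here")).unwrap_err();
    let msg = if let Some(msg) = err.downcast_ref::<&str>() {
        format!("Lambda panicked: {msg}")
    } else {
        "Lambda panicked".to_string()
    };
    assert_eq!(msg, "Lambda panicked: This is intentionally here");
}
```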
diff --git a/lambda-runtime/src/layers/trace.rs b/lambda-runtime/src/layers/trace.rs
new file mode 100644
index 00000000..0d635154
--- /dev/null
+++ b/lambda-runtime/src/layers/trace.rs
@@ -0,0 +1,68 @@
+use std::env;
+use tower::{Layer, Service};
+use tracing::{instrument::Instrumented, Instrument};
+
+use crate::{Context, LambdaInvocation};
+use lambda_runtime_api_client::BoxError;
+use std::task;
+
+/// Tower middleware to create a tracing span for invocations of the Lambda function.
+#[derive(Default)]
+pub struct TracingLayer {}
+
+impl TracingLayer {
+    /// Create a new tracing layer.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl<S> Layer<S> for TracingLayer {
+    type Service = TracingService<S>;
+
+    fn layer(&self, inner: S) -> Self::Service {
+        TracingService { inner }
+    }
+}
+
+/// Tower service returned by [TracingLayer].
+pub struct TracingService<S> {
+    inner: S,
+}
+
+impl<S> Service<LambdaInvocation> for TracingService<S>
+where
+    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
+{
+    type Response = ();
+    type Error = BoxError;
+    type Future = Instrumented<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
+        let span = request_span(&req.context);
+        self.inner.call(req).instrument(span)
+    }
+}
+
+/* ------------------------------------------- UTILS ------------------------------------------- */
+
+fn request_span(ctx: &Context) -> tracing::Span {
+    match &ctx.xray_trace_id {
+        Some(trace_id) => {
+            env::set_var("_X_AMZN_TRACE_ID", trace_id);
+            tracing::info_span!(
+                "Lambda runtime invoke",
+                requestId = &ctx.request_id,
+                xrayTraceId = trace_id
+            )
+        }
+        None => {
+            env::remove_var("_X_AMZN_TRACE_ID");
+            tracing::info_span!("Lambda runtime invoke", requestId = &ctx.request_id)
+        }
+    }
+}
diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs
index 3fe56b03..9638df64 100644
--- a/lambda-runtime/src/lib.rs
+++ b/lambda-runtime/src/lib.rs
@@ -7,27 +7,22 @@
 //! Create a type that conforms to the [`tower::Service`] trait. This type can
 //! then be passed to the `lambda_runtime::run` function, which launches
 //! and runs the Lambda runtime.
-use ::tracing::{error, trace, Instrument};
-use bytes::Bytes;
-use futures::FutureExt;
-use http_body_util::BodyExt;
-use hyper::{body::Incoming, http::Request};
-use lambda_runtime_api_client::{body::Body, BoxError, Client};
 use serde::{Deserialize, Serialize};
 use std::{
-    borrow::Cow,
     env,
     fmt::{self, Debug},
     future::Future,
-    panic,
     sync::Arc,
 };
-use tokio_stream::{Stream, StreamExt};
+use tokio_stream::Stream;
+use tower::util::ServiceFn;
 pub use tower::{self, service_fn, Service};
-use tower::{util::ServiceFn, ServiceExt};
 
 mod deserializer;
+/// Tower middleware to be applied to runtime invocations.
+pub mod layers;
 mod requests;
+mod runtime;
 /// Utilities for Lambda Streaming functions.
 pub mod streaming;
@@ -38,13 +33,12 @@ pub use lambda_runtime_api_client::tracing;
 /// Types available to a Lambda function.
 mod types;
 
-use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
+use requests::EventErrorRequest;
+pub use runtime::{LambdaInvocation, Runtime};
 pub use types::{
     Context, Diagnostic, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse,
 };
 
-use types::invoke_request_id;
-
 /// Error type that lambdas may result in
 pub type Error = lambda_runtime_api_client::BoxError;
@@ -90,135 +84,12 @@ where
     service_fn(move |req: LambdaEvent<A>| f(req.payload, req.context))
 }
 
-struct Runtime {
-    client: Client,
-    config: RefConfig,
-}
-
-impl Runtime {
-    async fn run<F, A, R, B, S, D, E>(
-        &self,
-        incoming: impl Stream<Item = Result<http::Response<Incoming>, Error>> + Send,
-        mut handler: F,
-    ) -> Result<(), BoxError>
-    where
-        F: Service<LambdaEvent<A>>,
-        F::Future: Future<Output = Result<R, F::Error>>,
-        F::Error: for<'a> Into<Diagnostic<'a>> + fmt::Debug,
-        A: for<'de> Deserialize<'de>,
-        R: IntoFunctionResponse<B, S>,
-        B: Serialize,
-        S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
-        D: Into<Bytes> + Send,
-        E: Into<Error> + Send + Debug,
-    {
-        let client = &self.client;
-        tokio::pin!(incoming);
-        while let Some(next_event_response) = incoming.next().await {
-            trace!("New event arrived (run loop)");
-            let event = next_event_response?;
-            let (parts, body) = event.into_parts();
-            let request_id = invoke_request_id(&parts.headers)?;
-
-            #[cfg(debug_assertions)]
-            if parts.status == http::StatusCode::NO_CONTENT {
-                // Ignore the event if the status code is 204.
-                // This is a way to keep the runtime alive when
-                // there are no events pending to be processed.
-                continue;
-            }
-
-            let ctx: Context = Context::new(request_id, self.config.clone(), &parts.headers)?;
-            let request_span = ctx.request_span();
-
-            // Group the handling in one future and instrument it with the span
-            async {
-                let body = body.collect().await?.to_bytes();
-                trace!(
-                    body = std::str::from_utf8(&body)?,
-                    "raw JSON event received from Lambda"
-                );
-
-                #[cfg(debug_assertions)]
-                if parts.status.is_server_error() {
-                    error!("Lambda Runtime server returned an unexpected error");
-                    return Err(parts.status.to_string().into());
-                }
-
-                let lambda_event = match deserializer::deserialize(&body, ctx) {
-                    Ok(lambda_event) => lambda_event,
-                    Err(err) => {
-                        let req = build_event_error_request(request_id, err)?;
-                        client.call(req).await.expect("Unable to send response to Runtime APIs");
-                        return Ok(());
-                    }
-                };
-
-                let req = match handler.ready().await {
-                    Ok(handler) => {
-                        // Catches panics outside of a `Future`
-                        let task = panic::catch_unwind(panic::AssertUnwindSafe(|| handler.call(lambda_event)));
-
-                        let task = match task {
-                            // Catches panics inside of the `Future`
-                            Ok(task) => panic::AssertUnwindSafe(task).catch_unwind().await,
-                            Err(err) => Err(err),
-                        };
-
-                        match task {
-                            Ok(response) => match response {
-                                Ok(response) => {
-                                    trace!("Ok response from handler (run loop)");
-                                    EventCompletionRequest::new(request_id, response).into_req()
-                                }
-                                Err(err) => build_event_error_request(request_id, err),
-                            },
-                            Err(err) => {
-                                error!("{:?}", err);
-                                let error_type = type_name_of_val(&err);
-                                let msg = if let Some(msg) = err.downcast_ref::<&str>() {
-                                    format!("Lambda panicked: {msg}")
-                                } else {
-                                    "Lambda panicked".to_string()
-                                };
-                                EventErrorRequest::new(
-                                    request_id,
-                                    Diagnostic {
-                                        error_type: Cow::Borrowed(error_type),
-                                        error_message: Cow::Owned(msg),
-                                    },
-                                )
-                                .into_req()
-                            }
-                        }
-                    }
-                    Err(err) => build_event_error_request(request_id, err),
-                }?;
-
-                client.call(req).await.expect("Unable to send response to Runtime APIs");
-                Ok::<(), Error>(())
-            }
-            .instrument(request_span)
-            .await?;
-        }
-        Ok(())
-    }
-}
-
-fn incoming(client: &Client) -> impl Stream<Item = Result<http::Response<Incoming>, Error>> + Send + '_ {
-    async_stream::stream! {
-        loop {
-            trace!("Waiting for next event (incoming loop)");
-            let req = NextEventRequest.into_req().expect("Unable to construct request");
-            let res = client.call(req).await;
-            yield res;
-        }
-    }
-}
-
 /// Starts the Lambda Rust runtime and begins polling for events on the [Lambda
 /// Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html).
 ///
+/// If you need more control over the runtime, or want to add custom middleware, use the
+/// [Runtime] type directly.
+///
 /// # Example
 /// ```no_run
 /// use lambda_runtime::{Error, service_fn, LambdaEvent};
 ///
@@ -237,272 +108,16 @@ pub async fn run<A, F, R, B, S, D, E>(handler: F) -> Result<(), Error>
 where
-    F: Service<LambdaEvent<A>>,
+    F: Service<LambdaEvent<A>, Response = R>,
     F::Future: Future<Output = Result<R, F::Error>>,
     F::Error: for<'a> Into<Diagnostic<'a>> + fmt::Debug,
     A: for<'de> Deserialize<'de>,
     R: IntoFunctionResponse<B, S>,
     B: Serialize,
     S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
-    D: Into<Bytes> + Send,
+    D: Into<bytes::Bytes> + Send,
     E: Into<Error> + Send + Debug,
 {
-    trace!("Loading config from env");
-    let config = Config::from_env();
-    let client = Client::builder().build().expect("Unable to create a runtime client");
-    let runtime = Runtime {
-        client,
-        config: Arc::new(config),
-    };
-
-    let client = &runtime.client;
-    let incoming = incoming(client);
-    runtime.run(incoming, handler).await
-}
-
-fn type_name_of_val<T>(_: T) -> &'static str {
-    std::any::type_name::<T>()
-}
-
-fn build_event_error_request<'a, T>(request_id: &'a str, err: T) -> Result<Request<Body>, Error>
-where
-    T: Into<Diagnostic<'a>> + Debug,
-{
-    error!("{:?}", err); // logs the error in CloudWatch
-    EventErrorRequest::new(request_id, err).into_req()
-}
-
-#[cfg(test)]
-mod endpoint_tests {
-    use crate::{
-        incoming,
-        requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
-        types::Diagnostic,
-        Config, Error, Runtime,
-    };
-    use futures::future::BoxFuture;
-    use http::{HeaderValue, StatusCode};
-    use http_body_util::BodyExt;
-    use httpmock::prelude::*;
-
-    use lambda_runtime_api_client::Client;
-    use std::{borrow::Cow, env, sync::Arc};
-    use tokio_stream::StreamExt;
-
-    #[tokio::test]
-    async fn test_next_event() -> Result<(), Error> {
-        let server = MockServer::start();
-        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
-        let deadline = "1542409706888";
-
-        let mock = server.mock(|when, then| {
-            when.method(GET).path("/2018-06-01/runtime/invocation/next");
-            then.status(200)
-                .header("content-type", "application/json")
-                .header("lambda-runtime-aws-request-id", request_id)
-                .header("lambda-runtime-deadline-ms", deadline)
-                .body("{}");
-        });
-
-        let base = server.base_url().parse().expect("Invalid mock server Uri");
-        let client = Client::builder().with_endpoint(base).build()?;
-
-        let req = NextEventRequest.into_req()?;
-        let rsp = client.call(req).await.expect("Unable to send request");
-
-        mock.assert_async().await;
-        assert_eq!(rsp.status(), StatusCode::OK);
-        assert_eq!(
-            rsp.headers()["lambda-runtime-aws-request-id"],
-            &HeaderValue::from_static(request_id)
-        );
-        assert_eq!(
-            rsp.headers()["lambda-runtime-deadline-ms"],
-            &HeaderValue::from_static(deadline)
-        );
-
-        let body = rsp.into_body().collect().await?.to_bytes();
-        assert_eq!("{}", std::str::from_utf8(&body)?);
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_ok_response() -> Result<(), Error> {
-        let server = MockServer::start();
-
-        let mock = server.mock(|when, then| {
-            when.method(POST)
-                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
-                .body("\"{}\"");
-            then.status(200).body("");
-        });
-
-        let base = server.base_url().parse().expect("Invalid mock server Uri");
-        let client = Client::builder().with_endpoint(base).build()?;
-
-        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
-        let req = req.into_req()?;
-
-        let rsp = client.call(req).await?;
-
-        mock.assert_async().await;
-        assert_eq!(rsp.status(), StatusCode::OK);
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_error_response() -> Result<(), Error> {
-        let diagnostic = Diagnostic {
-            error_type: Cow::Borrowed("InvalidEventDataError"),
Cow::Borrowed("InvalidEventDataError"), - error_message: Cow::Borrowed("Error parsing event data"), - }; - let body = serde_json::to_string(&diagnostic)?; - - let server = MockServer::start(); - let mock = server.mock(|when, then| { - when.method(POST) - .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error") - .header("lambda-runtime-function-error-type", "unhandled") - .body(body); - then.status(200).body(""); - }); - - let base = server.base_url().parse().expect("Invalid mock server Uri"); - let client = Client::builder().with_endpoint(base).build()?; - - let req = EventErrorRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", - diagnostic, - }; - let req = req.into_req()?; - let rsp = client.call(req).await?; - - mock.assert_async().await; - assert_eq!(rsp.status(), StatusCode::OK); - Ok(()) - } - - #[tokio::test] - async fn successful_end_to_end_run() -> Result<(), Error> { - let server = MockServer::start(); - let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; - let deadline = "1542409706888"; - - let next_request = server.mock(|when, then| { - when.method(GET).path("/2018-06-01/runtime/invocation/next"); - then.status(200) - .header("content-type", "application/json") - .header("lambda-runtime-aws-request-id", request_id) - .header("lambda-runtime-deadline-ms", deadline) - .body("{}"); - }); - let next_response = server.mock(|when, then| { - when.method(POST) - .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id)) - .body("{}"); - then.status(200).body(""); - }); - - let base = server.base_url().parse().expect("Invalid mock server Uri"); - let client = Client::builder().with_endpoint(base).build()?; - - async fn func(event: crate::LambdaEvent) -> Result { - let (event, _) = event.into_parts(); - Ok(event) - } - let f = crate::service_fn(func); - - // set env vars needed to init Config if they are not already set in the environment - if env::var("AWS_LAMBDA_RUNTIME_API").is_err() { - env::set_var("AWS_LAMBDA_RUNTIME_API", server.base_url()); - } - if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn"); - } - if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"); - } - if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1"); - } - if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() { - env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream"); - } - if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() { - env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log"); - } - let config = Config::from_env(); - - let runtime = Runtime { - client, - config: Arc::new(config), - }; - let client = &runtime.client; - let incoming = incoming(client).take(1); - runtime.run(incoming, f).await?; - - next_request.assert_async().await; - next_response.assert_async().await; - Ok(()) - } - - async fn run_panicking_handler(func: F) -> Result<(), Error> - where - F: FnMut(crate::LambdaEvent) -> BoxFuture<'static, Result>, - { - let server = MockServer::start(); - let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9"; - let deadline = "1542409706888"; - - let next_request = server.mock(|when, then| { - when.method(GET).path("/2018-06-01/runtime/invocation/next"); - then.status(200) - .header("content-type", "application/json") - .header("lambda-runtime-aws-request-id", request_id) - .header("lambda-runtime-deadline-ms", deadline) - .body("{}"); - }); - - let next_response = server.mock(|when, 
-            when.method(POST)
-                .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id))
-                .header("lambda-runtime-function-error-type", "unhandled");
-            then.status(200).body("");
-        });
-
-        let base = server.base_url().parse().expect("Invalid mock server Uri");
-        let client = Client::builder().with_endpoint(base).build()?;
-
-        let f = crate::service_fn(func);
-
-        let config = Arc::new(Config {
-            function_name: "test_fn".to_string(),
-            memory: 128,
-            version: "1".to_string(),
-            log_stream: "test_stream".to_string(),
-            log_group: "test_log".to_string(),
-        });
-
-        let runtime = Runtime { client, config };
-        let client = &runtime.client;
-        let incoming = incoming(client).take(1);
-        runtime.run(incoming, f).await?;
-
-        next_request.assert_async().await;
-        next_response.assert_async().await;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn panic_in_async_run() -> Result<(), Error> {
-        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
-    }
-
-    #[tokio::test]
-    async fn panic_outside_async_run() -> Result<(), Error> {
-        run_panicking_handler(|_| {
-            panic!("This is intentionally here");
-        })
-        .await
-    }
+    let runtime = Runtime::new(handler).layer(layers::TracingLayer::new());
+    runtime.run().await
 }
diff --git a/lambda-runtime/src/runtime.rs b/lambda-runtime/src/runtime.rs
new file mode 100644
index 00000000..0fc328cf
--- /dev/null
+++ b/lambda-runtime/src/runtime.rs
@@ -0,0 +1,481 @@
+use super::requests::{IntoRequest, NextEventRequest};
+use super::types::{invoke_request_id, Diagnostic, IntoFunctionResponse, LambdaEvent};
+use crate::layers::{CatchPanicService, RuntimeApiClientService, RuntimeApiResponseService};
+use crate::{Config, Context};
+use http_body_util::BodyExt;
+use lambda_runtime_api_client::BoxError;
+use lambda_runtime_api_client::Client as ApiClient;
+use serde::{Deserialize, Serialize};
+use std::fmt::Debug;
+use std::future::Future;
+use std::sync::Arc;
+use tokio_stream::{Stream, StreamExt};
+use tower::Layer;
+use tower::{Service, ServiceExt};
+use tracing::trace;
+
+/* ----------------------------------------- INVOCATION ---------------------------------------- */
+
+/// A simple container that provides information about a single invocation of a Lambda function.
+pub struct LambdaInvocation {
+    /// The header of the request sent to invoke the Lambda function.
+    pub parts: http::response::Parts,
+    /// The body of the request sent to invoke the Lambda function.
+    pub body: bytes::Bytes,
+    /// The context of the Lambda invocation.
+    pub context: Context,
+}
+
+/* ------------------------------------------ RUNTIME ------------------------------------------ */
+
+/// Lambda runtime executing a handler function on incoming requests.
+///
+/// Middleware can be added to a runtime using the [Runtime::layer] method in order to execute
+/// logic prior to processing the incoming request and/or after the response has been sent back
+/// to the Lambda Runtime API.
+///
+/// # Example
+/// ```no_run
+/// use lambda_runtime::{Error, LambdaEvent, Runtime};
+/// use serde_json::Value;
+/// use tower::service_fn;
+///
+/// #[tokio::main]
+/// async fn main() -> Result<(), Error> {
+///     let func = service_fn(func);
+///     Runtime::new(func).run().await?;
+///     Ok(())
+/// }
+///
+/// async fn func(event: LambdaEvent<Value>) -> Result<Value, Error> {
+///     Ok(event.payload)
+/// }
+/// ```
+pub struct Runtime<S> {
+    service: S,
+    config: Arc<Config>,
+    client: Arc<ApiClient>,
+}
+
+impl<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>
+    Runtime<
+        RuntimeApiClientService<
+            RuntimeApiResponseService<
+                CatchPanicService<'a, F>,
+                EventPayload,
+                Response,
+                BufferedResponse,
+                StreamingResponse,
+                StreamItem,
+                StreamError,
+            >,
+        >,
+    >
+where
+    F: Service<LambdaEvent<EventPayload>, Response = Response>,
+    F::Future: Future<Output = Result<Response, F::Error>>,
+    F::Error: Into<Diagnostic<'a>> + Debug,
+    EventPayload: for<'de> Deserialize<'de>,
+    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
+    BufferedResponse: Serialize,
+    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
+    StreamItem: Into<bytes::Bytes> + Send,
+    StreamError: Into<BoxError> + Send + Debug,
+{
+    /// Create a new runtime that executes the provided handler for incoming requests.
+    ///
+    /// In order to start the runtime and poll for events on the [Lambda Runtime
+    /// APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html), you must call
+    /// [Runtime::run].
+    ///
+    /// Note that manually creating a [Runtime] does not add tracing to the executed handler
+    /// as is done by [super::run]. If you want to add the default tracing functionality, call
+    /// [Runtime::layer] with a [super::layers::TracingLayer].
+    pub fn new(handler: F) -> Self {
+        trace!("Loading config from env");
+        let config = Arc::new(Config::from_env());
+        let client = Arc::new(ApiClient::builder().build().expect("Unable to create a runtime client"));
+        Self {
+            service: wrap_handler(handler, client.clone()),
+            config,
+            client,
+        }
+    }
+}
+
+impl<S> Runtime<S> {
+    /// Add a new layer to this runtime. For an incoming request, this layer will be executed
+    /// before any layer that has been added prior.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use lambda_runtime::{layers, Error, LambdaEvent, Runtime};
+    /// use serde_json::Value;
+    /// use tower::service_fn;
+    ///
+    /// #[tokio::main]
+    /// async fn main() -> Result<(), Error> {
+    ///     let runtime = Runtime::new(service_fn(echo)).layer(
+    ///         layers::TracingLayer::new()
+    ///     );
+    ///     runtime.run().await?;
+    ///     Ok(())
+    /// }
+    ///
+    /// async fn echo(event: LambdaEvent<Value>) -> Result<Value, Error> {
+    ///     Ok(event.payload)
+    /// }
+    /// ```
+    pub fn layer<L>(self, layer: L) -> Runtime<L::Service>
+    where
+        L: Layer<S>,
+        L::Service: Service<LambdaInvocation, Response = (), Error = BoxError>,
+    {
+        Runtime {
+            client: self.client,
+            config: self.config,
+            service: layer.layer(self.service),
+        }
+    }
+}
+
+impl<S> Runtime<S>
+where
+    S: Service<LambdaInvocation, Response = (), Error = BoxError>,
+{
+    /// Start the runtime and begin polling for events on the Lambda Runtime API.
+    pub async fn run(self) -> Result<(), BoxError> {
+        let incoming = incoming(&self.client);
+        Self::run_with_incoming(self.service, self.config, incoming).await
+    }
+
+    /// Internal utility function to start the runtime with a customized incoming stream.
+    /// This implements the core of the [Runtime::run] method.
+    pub(crate) async fn run_with_incoming(
+        mut service: S,
+        config: Arc<Config>,
+        incoming: impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send,
+    ) -> Result<(), BoxError> {
+        tokio::pin!(incoming);
+        while let Some(next_event_response) = incoming.next().await {
+            trace!("New event arrived (run loop)");
+            let event = next_event_response?;
+            let (parts, incoming) = event.into_parts();
+
+            #[cfg(debug_assertions)]
+            if parts.status == http::StatusCode::NO_CONTENT {
+                // Ignore the event if the status code is 204.
+                // This is a way to keep the runtime alive when
+                // there are no events pending to be processed.
+                continue;
+            }
+
+            // Build the invocation such that it can be sent to the service right away
+            // when it is ready
+            let body = incoming.collect().await?.to_bytes();
+            let context = Context::new(invoke_request_id(&parts.headers)?, config.clone(), &parts.headers)?;
+            let invocation = LambdaInvocation { parts, body, context };
+
+            // Wait for service to be ready
+            let ready = service.ready().await?;
+
+            // Once ready, call the service which will respond to the Lambda runtime API
+            ready.call(invocation).await?;
+        }
+        Ok(())
+    }
+}
+
+/* ------------------------------------------- UTILS ------------------------------------------- */
+
+#[allow(clippy::type_complexity)]
+fn wrap_handler<'a, F, EventPayload, Response, BufferedResponse, StreamingResponse, StreamItem, StreamError>(
+    handler: F,
+    client: Arc<ApiClient>,
+) -> RuntimeApiClientService<
+    RuntimeApiResponseService<
+        CatchPanicService<'a, F>,
+        EventPayload,
+        Response,
+        BufferedResponse,
+        StreamingResponse,
+        StreamItem,
+        StreamError,
+    >,
+>
+where
+    F: Service<LambdaEvent<EventPayload>, Response = Response>,
+    F::Future: Future<Output = Result<Response, F::Error>>,
+    F::Error: Into<Diagnostic<'a>> + Debug,
+    EventPayload: for<'de> Deserialize<'de>,
+    Response: IntoFunctionResponse<BufferedResponse, StreamingResponse>,
+    BufferedResponse: Serialize,
+    StreamingResponse: Stream<Item = Result<StreamItem, StreamError>> + Unpin + Send + 'static,
+    StreamItem: Into<bytes::Bytes> + Send,
+    StreamError: Into<BoxError> + Send + Debug,
+{
+    let safe_service = CatchPanicService::new(handler);
+    let response_service = RuntimeApiResponseService::new(safe_service);
+    RuntimeApiClientService::new(response_service, client)
+}
+
+fn incoming(
+    client: &ApiClient,
+) -> impl Stream<Item = Result<http::Response<hyper::body::Incoming>, BoxError>> + Send + '_ {
+    async_stream::stream! {
+        loop {
+            trace!("Waiting for next event (incoming loop)");
+            let req = NextEventRequest.into_req().expect("Unable to construct request");
+            let res = client.call(req).await;
+            yield res;
+        }
+    }
+}
+
+/* --------------------------------------------------------------------------------------------- */
+/*                                             TESTS                                              */
+/* --------------------------------------------------------------------------------------------- */
+
+#[cfg(test)]
+mod endpoint_tests {
+    use super::{incoming, wrap_handler};
+    use crate::{
+        requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest},
+        types::Diagnostic,
+        Config, Error, Runtime,
+    };
+    use futures::future::BoxFuture;
+    use http::{HeaderValue, StatusCode};
+    use http_body_util::BodyExt;
+    use httpmock::prelude::*;
+
+    use lambda_runtime_api_client::Client;
+    use std::{borrow::Cow, env, sync::Arc};
+    use tokio_stream::StreamExt;
+
+    #[tokio::test]
+    async fn test_next_event() -> Result<(), Error> {
+        let server = MockServer::start();
+        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
+        let deadline = "1542409706888";
+
+        let mock = server.mock(|when, then| {
+            when.method(GET).path("/2018-06-01/runtime/invocation/next");
+            then.status(200)
+                .header("content-type", "application/json")
+                .header("lambda-runtime-aws-request-id", request_id)
+                .header("lambda-runtime-deadline-ms", deadline)
+                .body("{}");
+        });
+
+        let base = server.base_url().parse().expect("Invalid mock server Uri");
+        let client = Client::builder().with_endpoint(base).build()?;
+
+        let req = NextEventRequest.into_req()?;
+        let rsp = client.call(req).await.expect("Unable to send request");
+
+        mock.assert_async().await;
+        assert_eq!(rsp.status(), StatusCode::OK);
+        assert_eq!(
+            rsp.headers()["lambda-runtime-aws-request-id"],
+            &HeaderValue::from_static(request_id)
+        );
+        assert_eq!(
+            rsp.headers()["lambda-runtime-deadline-ms"],
+            &HeaderValue::from_static(deadline)
+        );
+
+        let body = rsp.into_body().collect().await?.to_bytes();
+        assert_eq!("{}", std::str::from_utf8(&body)?);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_ok_response() -> Result<(), Error> {
+        let server = MockServer::start();
+
+        let mock = server.mock(|when, then| {
+            when.method(POST)
+                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/response")
+                .body("\"{}\"");
+            then.status(200).body("");
+        });
+
+        let base = server.base_url().parse().expect("Invalid mock server Uri");
+        let client = Client::builder().with_endpoint(base).build()?;
+
+        let req = EventCompletionRequest::new("156cb537-e2d4-11e8-9b34-d36013741fb9", "{}");
+        let req = req.into_req()?;
+
+        let rsp = client.call(req).await?;
+
+        mock.assert_async().await;
+        assert_eq!(rsp.status(), StatusCode::OK);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_error_response() -> Result<(), Error> {
+        let diagnostic = Diagnostic {
+            error_type: Cow::Borrowed("InvalidEventDataError"),
+            error_message: Cow::Borrowed("Error parsing event data"),
+        };
+        let body = serde_json::to_string(&diagnostic)?;
+
+        let server = MockServer::start();
+        let mock = server.mock(|when, then| {
+            when.method(POST)
+                .path("/2018-06-01/runtime/invocation/156cb537-e2d4-11e8-9b34-d36013741fb9/error")
+                .header("lambda-runtime-function-error-type", "unhandled")
+                .body(body);
+            then.status(200).body("");
+        });
+
+        let base = server.base_url().parse().expect("Invalid mock server Uri");
+        let client = Client::builder().with_endpoint(base).build()?;
+
+        let req = EventErrorRequest {
+            request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9",
+            diagnostic,
+        };
+        let req = req.into_req()?;
+        let rsp = client.call(req).await?;
+
+        mock.assert_async().await;
+        assert_eq!(rsp.status(), StatusCode::OK);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn successful_end_to_end_run() -> Result<(), Error> {
+        let server = MockServer::start();
+        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
+        let deadline = "1542409706888";
+
+        let next_request = server.mock(|when, then| {
+            when.method(GET).path("/2018-06-01/runtime/invocation/next");
+            then.status(200)
+                .header("content-type", "application/json")
+                .header("lambda-runtime-aws-request-id", request_id)
+                .header("lambda-runtime-deadline-ms", deadline)
+                .body("{}");
+        });
+        let next_response = server.mock(|when, then| {
+            when.method(POST)
+                .path(format!("/2018-06-01/runtime/invocation/{}/response", request_id))
+                .body("{}");
+            then.status(200).body("");
+        });
+
+        let base = server.base_url().parse().expect("Invalid mock server Uri");
+        let client = Client::builder().with_endpoint(base).build()?;
+
+        async fn func(event: crate::LambdaEvent<serde_json::Value>) -> Result<serde_json::Value, Error> {
+            let (event, _) = event.into_parts();
+            Ok(event)
+        }
+        let f = crate::service_fn(func);
+
+        // set env vars needed to init Config if they are not already set in the environment
+        if env::var("AWS_LAMBDA_RUNTIME_API").is_err() {
+            env::set_var("AWS_LAMBDA_RUNTIME_API", server.base_url());
+        }
+        if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() {
+            env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn");
+        }
+        if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() {
+            env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128");
+        }
+        if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() {
+            env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1");
+        }
+        if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() {
+            env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream");
+        }
+        if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
+            env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
+        }
+        let config = Config::from_env();
+
+        let client = Arc::new(client);
+        let runtime = Runtime {
+            client: client.clone(),
+            config: Arc::new(config),
+            service: wrap_handler(f, client),
+        };
+        let client = &runtime.client;
+        let incoming = incoming(client).take(1);
+        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;
+
+        next_request.assert_async().await;
+        next_response.assert_async().await;
+        Ok(())
+    }
+
+    async fn run_panicking_handler<F>(func: F) -> Result<(), Error>
+    where
+        F: FnMut(crate::LambdaEvent<serde_json::Value>) -> BoxFuture<'static, Result<serde_json::Value, Error>>
+            + Send
+            + 'static,
+    {
+        let server = MockServer::start();
+        let request_id = "156cb537-e2d4-11e8-9b34-d36013741fb9";
+        let deadline = "1542409706888";
+
+        let next_request = server.mock(|when, then| {
+            when.method(GET).path("/2018-06-01/runtime/invocation/next");
+            then.status(200)
+                .header("content-type", "application/json")
+                .header("lambda-runtime-aws-request-id", request_id)
+                .header("lambda-runtime-deadline-ms", deadline)
+                .body("{}");
+        });
+
+        let next_response = server.mock(|when, then| {
+            when.method(POST)
+                .path(format!("/2018-06-01/runtime/invocation/{}/error", request_id))
+                .header("lambda-runtime-function-error-type", "unhandled");
+            then.status(200).body("");
+        });
+
+        let base = server.base_url().parse().expect("Invalid mock server Uri");
+        let client = Client::builder().with_endpoint(base).build()?;
+
+        let f = crate::service_fn(func);
+
+        let config = Arc::new(Config {
+            function_name: "test_fn".to_string(),
+            memory: 128,
+            version: "1".to_string(),
+            log_stream: "test_stream".to_string(),
+            log_group: "test_log".to_string(),
+        });
+
+        let client = Arc::new(client);
+        let runtime = Runtime {
+            client: client.clone(),
+            config,
+            service: wrap_handler(f, client),
+        };
+        let client = &runtime.client;
+        let incoming = incoming(client).take(1);
+        Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?;
+
+        next_request.assert_async().await;
+        next_response.assert_async().await;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn panic_in_async_run() -> Result<(), Error> {
+        run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await
+    }
+
+    #[tokio::test]
+    async fn panic_outside_async_run() -> Result<(), Error> {
+        run_panicking_handler(|_| {
+            panic!("This is intentionally here");
+        })
+        .await
+    }
+}
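To illustrate the extension point this refactoring opens up, here is a hypothetical user-defined middleware written against the same `LambdaInvocation` service contract used throughout runtime.rs. `TimingLayer` and `TimingService` are invented names for this sketch, not part of the diff, and the `S::Future: Send` bound is an assumption made so the example can box its future:

```rust
use futures::future::BoxFuture;
use lambda_runtime::{Error, LambdaEvent, LambdaInvocation, Runtime};
use serde_json::Value;
use std::task::{Context, Poll};
use std::time::Instant;
use tower::{service_fn, Layer, Service};

// Invented middleware: measure how long each invocation takes end to end.
struct TimingLayer;

impl<S> Layer<S> for TimingLayer {
    type Service = TimingService<S>;
    fn layer(&self, inner: S) -> Self::Service {
        TimingService { inner }
    }
}

struct TimingService<S> {
    inner: S,
}

impl<S> Service<LambdaInvocation> for TimingService<S>
where
    S: Service<LambdaInvocation, Response = (), Error = Error>,
    S::Future: Send + 'static,
{
    type Response = ();
    type Error = Error;
    type Future = BoxFuture<'static, Result<(), Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: LambdaInvocation) -> Self::Future {
        let start = Instant::now();
        let fut = self.inner.call(req);
        Box::pin(async move {
            // The timing below covers the full cycle, including the Runtime API call.
            let res = fut.await;
            tracing::info!(elapsed_ms = start.elapsed().as_millis() as u64, "invocation finished");
            res
        })
    }
}

async fn echo(event: LambdaEvent<Value>) -> Result<Value, Error> {
    Ok(event.payload)
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    Runtime::new(service_fn(echo)).layer(TimingLayer).run().await
}
```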
"test_stream".to_string(), + log_group: "test_log".to_string(), + }); + + let client = Arc::new(client); + let runtime = Runtime { + client: client.clone(), + config, + service: wrap_handler(f, client), + }; + let client = &runtime.client; + let incoming = incoming(client).take(1); + Runtime::run_with_incoming(runtime.service, runtime.config, incoming).await?; + + next_request.assert_async().await; + next_response.assert_async().await; + Ok(()) + } + + #[tokio::test] + async fn panic_in_async_run() -> Result<(), Error> { + run_panicking_handler(|_| Box::pin(async { panic!("This is intentionally here") })).await + } + + #[tokio::test] + async fn panic_outside_async_run() -> Result<(), Error> { + run_panicking_handler(|_| { + panic!("This is intentionally here"); + }) + .await + } +} diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index 478f88fd..b4f10f71 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -7,12 +7,10 @@ use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, collections::HashMap, - env, fmt::{Debug, Display}, time::{Duration, SystemTime}, }; use tokio_stream::Stream; -use tracing::Span; /// Diagnostic information about an error. /// @@ -209,24 +207,6 @@ impl Context { pub fn deadline(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline) } - - /// Create a new [`tracing::Span`] for an incoming invocation. - pub(crate) fn request_span(&self) -> Span { - match &self.xray_trace_id { - Some(trace_id) => { - env::set_var("_X_AMZN_TRACE_ID", trace_id); - tracing::info_span!( - "Lambda runtime invoke", - requestId = &self.request_id, - xrayTraceId = trace_id - ) - } - None => { - env::remove_var("_X_AMZN_TRACE_ID"); - tracing::info_span!("Lambda runtime invoke", requestId = &self.request_id) - } - } - } } /// Extract the invocation request id from the incoming request.