diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index f26fa351..2bc05067 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -100,7 +100,10 @@ pub mod request; mod response; mod strmap; pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap}; -use crate::{request::LambdaRequest, response::LambdaResponse}; +use crate::{ + request::{LambdaRequest, RequestOrigin}, + response::LambdaResponse, +}; use std::{ future::Future, pin::Pin, @@ -149,7 +152,7 @@ where #[doc(hidden)] pub struct TransformResponse { - is_alb: bool, + request_origin: RequestOrigin, fut: Pin>>>, } @@ -160,9 +163,9 @@ where type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll { match self.fut.as_mut().poll(cx) { - Poll::Ready(result) => { - Poll::Ready(result.map(|resp| LambdaResponse::from_response(self.is_alb, resp.into_response()))) - } + Poll::Ready(result) => Poll::Ready( + result.map(|resp| LambdaResponse::from_response(&self.request_origin, resp.into_response())), + ), Poll::Pending => Poll::Pending, } } @@ -192,8 +195,8 @@ impl LambdaHandler, LambdaResponse> for Adapter type Error = H::Error; type Fut = TransformResponse; fn call(&mut self, event: LambdaRequest<'_>, context: Context) -> Self::Fut { - let is_alb = event.is_alb(); + let request_origin = event.request_origin(); let fut = Box::pin(self.handler.call(event.into(), context)); - TransformResponse { is_alb, fut } + TransformResponse { request_origin, fut } } } diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index 347cb29e..53585738 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -90,17 +90,30 @@ pub enum LambdaRequest<'a> { } impl LambdaRequest<'_> { - /// Return true if this request represents an ALB event - /// - /// Alb responses have unique requirements for responses that - /// vary only slightly from APIGateway responses. We serialize - /// responses capturing a hint that the request was an alb triggered - /// event. - pub fn is_alb(&self) -> bool { - matches!(self, LambdaRequest::Alb { .. }) + /// Return the `RequestOrigin` of the request to determine where the `LambdaRequest` + /// originated from, so that the appropriate response can be selected based on what + /// type of response the request origin expects. + pub fn request_origin(&self) -> RequestOrigin { + match self { + LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2, + LambdaRequest::Alb { .. } => RequestOrigin::Alb, + LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway, + } } } +/// Represents the origin from which the lambda was requested from. +#[doc(hidden)] +#[derive(Debug)] +pub enum RequestOrigin { + /// API Gateway v2 request origin + ApiGatewayV2, + /// API Gateway request origin + ApiGateway, + /// ALB request origin + Alb, +} + #[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct ApiGatewayV2RequestContext { diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index 61406426..a9104899 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -1,49 +1,70 @@ //! Response types use http::{ - header::{HeaderMap, HeaderValue, CONTENT_TYPE}, + header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE}, Response, }; use serde::{ - ser::{Error as SerError, SerializeMap}, + ser::{Error as SerError, SerializeMap, SerializeSeq}, Serializer, }; use serde_derive::Serialize; use crate::body::Body; +use crate::request::RequestOrigin; -/// Representation of API Gateway response +/// Representation of Lambda response +#[doc(hidden)] +#[derive(Serialize, Debug)] +#[serde(untagged)] +pub enum LambdaResponse { + ApiGatewayV2(ApiGatewayV2Response), + Alb(AlbResponse), + ApiGateway(ApiGatewayResponse), +} + +/// Representation of API Gateway v2 lambda response #[doc(hidden)] #[derive(Serialize, Debug)] #[serde(rename_all = "camelCase")] -pub struct LambdaResponse { - pub status_code: u16, - // ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error - // when one is provided. only populate this for ALB responses +pub struct ApiGatewayV2Response { + status_code: u16, + #[serde(serialize_with = "serialize_headers")] + headers: HeaderMap, + #[serde(serialize_with = "serialize_headers_slice")] + cookies: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub status_description: Option, + body: Option, + is_base64_encoded: bool, +} + +/// Representation of ALB lambda response +#[doc(hidden)] +#[derive(Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct AlbResponse { + status_code: u16, + status_description: String, #[serde(serialize_with = "serialize_headers")] - pub headers: HeaderMap, - #[serde(serialize_with = "serialize_multi_value_headers")] - pub multi_value_headers: HeaderMap, + headers: HeaderMap, #[serde(skip_serializing_if = "Option::is_none")] - pub body: Option, - // This field is optional for API Gateway but required for ALB - pub is_base64_encoded: bool, + body: Option, + is_base64_encoded: bool, } -#[cfg(test)] -impl Default for LambdaResponse { - fn default() -> Self { - Self { - status_code: 200, - status_description: Default::default(), - headers: Default::default(), - multi_value_headers: Default::default(), - body: Default::default(), - is_base64_encoded: Default::default(), - } - } +/// Representation of API Gateway lambda response +#[doc(hidden)] +#[derive(Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ApiGatewayResponse { + status_code: u16, + #[serde(serialize_with = "serialize_headers")] + headers: HeaderMap, + #[serde(serialize_with = "serialize_multi_value_headers")] + multi_value_headers: HeaderMap, + #[serde(skip_serializing_if = "Option::is_none")] + body: Option, + is_base64_encoded: bool, } /// Serialize a http::HeaderMap into a serde str => str map @@ -75,9 +96,21 @@ where map.end() } +/// Serialize a &[HeaderValue] into a Vec +fn serialize_headers_slice(headers: &[HeaderValue], serializer: S) -> Result +where + S: Serializer, +{ + let mut seq = serializer.serialize_seq(Some(headers.len()))?; + for header in headers { + seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?; + } + seq.end() +} + /// tranformation from http type to internal type impl LambdaResponse { - pub(crate) fn from_response(is_alb: bool, value: Response) -> Self + pub(crate) fn from_response(request_origin: &RequestOrigin, value: Response) -> Self where T: Into, { @@ -87,21 +120,43 @@ impl LambdaResponse { b @ Body::Text(_) => (false, Some(b)), b @ Body::Binary(_) => (true, Some(b)), }; - Self { - status_code: parts.status.as_u16(), - status_description: if is_alb { - Some(format!( + + let mut headers = parts.headers; + let status_code = parts.status.as_u16(); + + match request_origin { + RequestOrigin::ApiGatewayV2 => { + // ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute, + // so remove them from the headers. + let cookies: Vec = headers.get_all(SET_COOKIE).iter().cloned().collect(); + headers.remove(SET_COOKIE); + + LambdaResponse::ApiGatewayV2(ApiGatewayV2Response { + body, + status_code, + is_base64_encoded, + cookies, + headers, + }) + } + RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse { + body, + status_code, + is_base64_encoded, + headers: headers.clone(), + multi_value_headers: headers, + }), + RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse { + body, + status_code, + is_base64_encoded, + headers, + status_description: format!( "{} {}", - parts.status.as_u16(), + status_code, parts.status.canonical_reason().unwrap_or_default() - )) - } else { - None - }, - body, - headers: parts.headers.clone(), - multi_value_headers: parts.headers, - is_base64_encoded, + ), + }), } } } @@ -161,10 +216,42 @@ impl IntoResponse for serde_json::Value { #[cfg(test)] mod tests { - use super::{Body, IntoResponse, LambdaResponse}; + use super::{ + AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin, + }; use http::{header::CONTENT_TYPE, Response}; use serde_json::{self, json}; + fn api_gateway_response() -> ApiGatewayResponse { + ApiGatewayResponse { + status_code: 200, + headers: Default::default(), + multi_value_headers: Default::default(), + body: Default::default(), + is_base64_encoded: Default::default(), + } + } + + fn alb_response() -> AlbResponse { + AlbResponse { + status_code: 200, + status_description: "200 OK".to_string(), + headers: Default::default(), + body: Default::default(), + is_base64_encoded: Default::default(), + } + } + + fn api_gateway_v2_response() -> ApiGatewayV2Response { + ApiGatewayV2Response { + status_code: 200, + headers: Default::default(), + body: Default::default(), + cookies: Default::default(), + is_base64_encoded: Default::default(), + } + } + #[test] fn json_into_response() { let response = json!({ "hello": "lambda"}).into_response(); @@ -191,32 +278,39 @@ mod tests { } #[test] - fn default_response() { - assert_eq!(LambdaResponse::default().status_code, 200) + fn serialize_body_for_api_gateway() { + let mut resp = api_gateway_response(); + resp.body = Some("foo".into()); + assert_eq!( + serde_json::to_string(&resp).expect("failed to serialize response"), + r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"# + ); } #[test] - fn serialize_default() { + fn serialize_body_for_alb() { + let mut resp = alb_response(); + resp.body = Some("foo".into()); assert_eq!( - serde_json::to_string(&LambdaResponse::default()).expect("failed to serialize response"), - r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"isBase64Encoded":false}"# + serde_json::to_string(&resp).expect("failed to serialize response"), + r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"# ); } #[test] - fn serialize_body() { - let mut resp = LambdaResponse::default(); + fn serialize_body_for_api_gateway_v2() { + let mut resp = api_gateway_v2_response(); resp.body = Some("foo".into()); assert_eq!( serde_json::to_string(&resp).expect("failed to serialize response"), - r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"# + r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"# ); } #[test] fn serialize_multi_value_headers() { let res = LambdaResponse::from_response( - false, + &RequestOrigin::ApiGateway, Response::builder() .header("multi", "a") .header("multi", "b") @@ -229,4 +323,21 @@ mod tests { r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"# ) } + + #[test] + fn serialize_cookies() { + let res = LambdaResponse::from_response( + &RequestOrigin::ApiGatewayV2, + Response::builder() + .header("set-cookie", "cookie1=a") + .header("set-cookie", "cookie2=b") + .body(Body::from(())) + .expect("failed to create response"), + ); + let json = serde_json::to_string(&res).expect("failed to serialize to json"); + assert_eq!( + json, + r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"# + ) + } }