Skip to content

Commit

Permalink
Lambda-http: vary type of response based on request origin
Browse files Browse the repository at this point in the history
ApiGatewayV2, ApiGateway and Alb all expect different types of responses to
be returned from the invoked lambda function. Thus, it makes sense to pass
the request origin to the creation of the response, so that the correct
type of LambdaResponse is returned from the function.

This commit also adds support for the "cookies" attribute which can be used
for returning multiple Set-cookie headers from a lambda invoked via
ApiGatewayV2, since ApiGatewayV2 no longer seems to recognize the
"multiValueHeaders" attribute.

Closes: #267.
  • Loading branch information
l3ku committed Nov 24, 2020
1 parent 13aa8f0 commit 374ed81
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 65 deletions.
17 changes: 10 additions & 7 deletions lambda-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ pub mod request;
mod response;
mod strmap;
pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap};
use crate::{request::LambdaRequest, response::LambdaResponse};
use crate::{
request::{LambdaRequest, RequestOrigin},
response::LambdaResponse,
};
use std::{
future::Future,
pin::Pin,
Expand Down Expand Up @@ -149,7 +152,7 @@ where

#[doc(hidden)]
pub struct TransformResponse<R, E> {
is_alb: bool,
request_origin: RequestOrigin,
fut: Pin<Box<dyn Future<Output = Result<R, E>>>>,
}

Expand All @@ -160,9 +163,9 @@ where
type Output = Result<LambdaResponse, E>;
fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll<Self::Output> {
match self.fut.as_mut().poll(cx) {
Poll::Ready(result) => {
Poll::Ready(result.map(|resp| LambdaResponse::from_response(self.is_alb, resp.into_response())))
}
Poll::Ready(result) => Poll::Ready(
result.map(|resp| LambdaResponse::from_response(&self.request_origin, resp.into_response())),
),
Poll::Pending => Poll::Pending,
}
}
Expand Down Expand Up @@ -192,8 +195,8 @@ impl<H: Handler> LambdaHandler<LambdaRequest<'_>, LambdaResponse> for Adapter<H>
type Error = H::Error;
type Fut = TransformResponse<H::Response, Self::Error>;
fn call(&mut self, event: LambdaRequest<'_>, context: Context) -> Self::Fut {
let is_alb = event.is_alb();
let request_origin = event.request_origin();
let fut = Box::pin(self.handler.call(event.into(), context));
TransformResponse { is_alb, fut }
TransformResponse { request_origin, fut }
}
}
29 changes: 21 additions & 8 deletions lambda-http/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,30 @@ pub enum LambdaRequest<'a> {
}

impl LambdaRequest<'_> {
/// Return true if this request represents an ALB event
///
/// Alb responses have unique requirements for responses that
/// vary only slightly from APIGateway responses. We serialize
/// responses capturing a hint that the request was an alb triggered
/// event.
pub fn is_alb(&self) -> bool {
matches!(self, LambdaRequest::Alb { .. })
/// Return the `RequestOrigin` of the request to determine where the `LambdaRequest`
/// originated from, so that the appropriate response can be selected based on what
/// type of response the request origin expects.
pub fn request_origin(&self) -> RequestOrigin {
match self {
LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2,
LambdaRequest::Alb { .. } => RequestOrigin::Alb,
LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway,
}
}
}

/// Represents the origin from which the lambda was requested from.
#[doc(hidden)]
#[derive(Debug)]
pub enum RequestOrigin {
/// API Gateway v2 request origin
ApiGatewayV2,
/// API Gateway request origin
ApiGateway,
/// ALB request origin
Alb,
}

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayV2RequestContext {
Expand Down
211 changes: 161 additions & 50 deletions lambda-http/src/response.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,70 @@
//! Response types

use http::{
header::{HeaderMap, HeaderValue, CONTENT_TYPE},
header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE},
Response,
};
use serde::{
ser::{Error as SerError, SerializeMap},
ser::{Error as SerError, SerializeMap, SerializeSeq},
Serializer,
};
use serde_derive::Serialize;

use crate::body::Body;
use crate::request::RequestOrigin;

/// Representation of API Gateway response
/// Representation of Lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub enum LambdaResponse {
ApiGatewayV2(ApiGatewayV2Response),
Alb(AlbResponse),
ApiGateway(ApiGatewayResponse),
}

/// Representation of API Gateway v2 lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct LambdaResponse {
pub status_code: u16,
// ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error
// when one is provided. only populate this for ALB responses
pub struct ApiGatewayV2Response {
status_code: u16,
#[serde(serialize_with = "serialize_headers")]
headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_headers_slice")]
cookies: Vec<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub status_description: Option<String>,
body: Option<Body>,
is_base64_encoded: bool,
}

/// Representation of ALB lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct AlbResponse {
status_code: u16,
status_description: String,
#[serde(serialize_with = "serialize_headers")]
pub headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_multi_value_headers")]
pub multi_value_headers: HeaderMap<HeaderValue>,
headers: HeaderMap<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
pub body: Option<Body>,
// This field is optional for API Gateway but required for ALB
pub is_base64_encoded: bool,
body: Option<Body>,
is_base64_encoded: bool,
}

#[cfg(test)]
impl Default for LambdaResponse {
fn default() -> Self {
Self {
status_code: 200,
status_description: Default::default(),
headers: Default::default(),
multi_value_headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}
/// Representation of API Gateway lambda response
#[doc(hidden)]
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayResponse {
status_code: u16,
#[serde(serialize_with = "serialize_headers")]
headers: HeaderMap<HeaderValue>,
#[serde(serialize_with = "serialize_multi_value_headers")]
multi_value_headers: HeaderMap<HeaderValue>,
#[serde(skip_serializing_if = "Option::is_none")]
body: Option<Body>,
is_base64_encoded: bool,
}

/// Serialize a http::HeaderMap into a serde str => str map
Expand Down Expand Up @@ -75,9 +96,21 @@ where
map.end()
}

/// Serialize a &[HeaderValue] into a Vec<str>
fn serialize_headers_slice<S>(headers: &[HeaderValue], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(headers.len()))?;
for header in headers {
seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?;
}
seq.end()
}

/// tranformation from http type to internal type
impl LambdaResponse {
pub(crate) fn from_response<T>(is_alb: bool, value: Response<T>) -> Self
pub(crate) fn from_response<T>(request_origin: &RequestOrigin, value: Response<T>) -> Self
where
T: Into<Body>,
{
Expand All @@ -87,21 +120,43 @@ impl LambdaResponse {
b @ Body::Text(_) => (false, Some(b)),
b @ Body::Binary(_) => (true, Some(b)),
};
Self {
status_code: parts.status.as_u16(),
status_description: if is_alb {
Some(format!(

let mut headers = parts.headers;
let status_code = parts.status.as_u16();

match request_origin {
RequestOrigin::ApiGatewayV2 => {
// ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
// so remove them from the headers.
let cookies: Vec<HeaderValue> = headers.get_all(SET_COOKIE).iter().cloned().collect();
headers.remove(SET_COOKIE);

LambdaResponse::ApiGatewayV2(ApiGatewayV2Response {
body,
status_code,
is_base64_encoded,
cookies,
headers,
})
}
RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse {
body,
status_code,
is_base64_encoded,
headers: headers.clone(),
multi_value_headers: headers,
}),
RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse {
body,
status_code,
is_base64_encoded,
headers,
status_description: format!(
"{} {}",
parts.status.as_u16(),
status_code,
parts.status.canonical_reason().unwrap_or_default()
))
} else {
None
},
body,
headers: parts.headers.clone(),
multi_value_headers: parts.headers,
is_base64_encoded,
),
}),
}
}
}
Expand Down Expand Up @@ -161,10 +216,42 @@ impl IntoResponse for serde_json::Value {

#[cfg(test)]
mod tests {
use super::{Body, IntoResponse, LambdaResponse};
use super::{
AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin,
};
use http::{header::CONTENT_TYPE, Response};
use serde_json::{self, json};

fn api_gateway_response() -> ApiGatewayResponse {
ApiGatewayResponse {
status_code: 200,
headers: Default::default(),
multi_value_headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}

fn alb_response() -> AlbResponse {
AlbResponse {
status_code: 200,
status_description: "200 OK".to_string(),
headers: Default::default(),
body: Default::default(),
is_base64_encoded: Default::default(),
}
}

fn api_gateway_v2_response() -> ApiGatewayV2Response {
ApiGatewayV2Response {
status_code: 200,
headers: Default::default(),
body: Default::default(),
cookies: Default::default(),
is_base64_encoded: Default::default(),
}
}

#[test]
fn json_into_response() {
let response = json!({ "hello": "lambda"}).into_response();
Expand All @@ -191,32 +278,39 @@ mod tests {
}

#[test]
fn default_response() {
assert_eq!(LambdaResponse::default().status_code, 200)
fn serialize_body_for_api_gateway() {
let mut resp = api_gateway_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_default() {
fn serialize_body_for_alb() {
let mut resp = alb_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&LambdaResponse::default()).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"isBase64Encoded":false}"#
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_body() {
let mut resp = LambdaResponse::default();
fn serialize_body_for_api_gateway_v2() {
let mut resp = api_gateway_v2_response();
resp.body = Some("foo".into());
assert_eq!(
serde_json::to_string(&resp).expect("failed to serialize response"),
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"#
);
}

#[test]
fn serialize_multi_value_headers() {
let res = LambdaResponse::from_response(
false,
&RequestOrigin::ApiGateway,
Response::builder()
.header("multi", "a")
.header("multi", "b")
Expand All @@ -229,4 +323,21 @@ mod tests {
r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
)
}

#[test]
fn serialize_cookies() {
let res = LambdaResponse::from_response(
&RequestOrigin::ApiGatewayV2,
Response::builder()
.header("set-cookie", "cookie1=a")
.header("set-cookie", "cookie2=b")
.body(Body::from(()))
.expect("failed to create response"),
);
let json = serde_json::to_string(&res).expect("failed to serialize to json");
assert_eq!(
json,
r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"#
)
}
}

0 comments on commit 374ed81

Please sign in to comment.