Skip to content

Commit

Permalink
✨ Add exclude_headers parameter (#280)
Browse files Browse the repository at this point in the history
* Add exclude_headers parameter

* add tests
  • Loading branch information
aminalaee authored Nov 24, 2022
1 parent 4147a90 commit f1106dd
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 7 deletions.
3 changes: 3 additions & 0 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
api_gateway_base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
exclude_headers: Optional[List[str]] = None,
) -> None:
if lifespan not in ("auto", "on", "off"):
raise ConfigurationError(
Expand All @@ -53,9 +54,11 @@ def __init__(
self.app = app
self.lifespan = lifespan
self.custom_handlers = custom_handlers or []
exclude_headers = exclude_headers or []
self.config = LambdaConfig(
api_gateway_base_path=api_gateway_base_path or "/",
text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES],
exclude_headers=[header.lower() for header in exclude_headers],
)

def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
Expand Down
7 changes: 5 additions & 2 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mangum.handlers.utils import (
get_server_and_port,
handle_base64_response_body,
handle_exclude_headers,
maybe_encode_body,
)
from mangum.types import (
Expand Down Expand Up @@ -166,8 +167,10 @@ def __call__(self, response: Response) -> dict:
# headers otherwise.
multi_value_headers_enabled = "multiValueHeaders" in self.scope["aws.event"]
if multi_value_headers_enabled:
out["multiValueHeaders"] = multi_value_headers
out["multiValueHeaders"] = handle_exclude_headers(
multi_value_headers, self.config
)
else:
out["headers"] = finalized_headers
out["headers"] = handle_exclude_headers(finalized_headers, self.config)

return out
7 changes: 5 additions & 2 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mangum.handlers.utils import (
get_server_and_port,
handle_base64_response_body,
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
strip_api_gateway_path,
Expand Down Expand Up @@ -120,8 +121,10 @@ def __call__(self, response: Response) -> dict:

return {
"statusCode": response["status"],
"headers": finalized_headers,
"multiValueHeaders": multi_value_headers,
"headers": handle_exclude_headers(finalized_headers, self.config),
"multiValueHeaders": handle_exclude_headers(
multi_value_headers, self.config
),
"body": finalized_body,
"isBase64Encoded": is_base64_encoded,
}
Expand Down
3 changes: 2 additions & 1 deletion mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from mangum.handlers.utils import (
handle_base64_response_body,
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
)
Expand Down Expand Up @@ -88,7 +89,7 @@ def __call__(self, response: Response) -> dict:

return {
"status": response["status"],
"headers": finalized_headers,
"headers": handle_exclude_headers(finalized_headers, self.config),
"body": response_body,
"isBase64Encoded": is_base64_encoded,
}
16 changes: 14 additions & 2 deletions mangum/handlers/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import base64
from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
from urllib.parse import unquote

from mangum.types import Headers
from mangum.types import Headers, LambdaConfig


def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes:
Expand Down Expand Up @@ -81,3 +81,15 @@ def handle_base64_response_body(
is_base64_encoded = True

return output_body, is_base64_encoded


def handle_exclude_headers(
headers: Dict[str, Any], config: LambdaConfig
) -> Dict[str, Any]:
finalized_headers = {}
for header_key, header_value in headers.items():
if header_key in config["exclude_headers"]:
continue
finalized_headers[header_key] = header_value

return finalized_headers
1 change: 1 addition & 0 deletions mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Response(TypedDict):
class LambdaConfig(TypedDict):
api_gateway_base_path: str
text_mime_types: List[str]
exclude_headers: List[str]


class LambdaHandler(Protocol):
Expand Down
37 changes: 37 additions & 0 deletions tests/handlers/test_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,40 @@ async def app(scope, receive, send):
"headers": {"content-type": content_type.decode()},
"body": utf_res_body,
}


@pytest.mark.parametrize("multi_value_headers_enabled", (True, False))
def test_aws_alb_exclude_headers(multi_value_headers_enabled) -> None:
async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"content-type", b"text/plain; charset=utf-8"],
[b"x-custom-header", b"test"],
],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})

handler = Mangum(app, lifespan="off", exclude_headers=["x-custom-header"])
event = get_mock_aws_alb_event(
"GET", "/test", {}, None, None, False, multi_value_headers_enabled
)
response = handler(event, {})

expected_response = {
"statusCode": 200,
"isBase64Encoded": False,
"body": "Hello, world!",
}
if multi_value_headers_enabled:
expected_response["multiValueHeaders"] = {
"content-type": ["text/plain; charset=utf-8"],
}
else:
expected_response["headers"] = {
"content-type": "text/plain; charset=utf-8",
}
assert response == expected_response
28 changes: 28 additions & 0 deletions tests/handlers/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,31 @@ async def app(scope, receive, send):
"multiValueHeaders": {},
"body": utf_res_body,
}


def test_aws_api_gateway_exclude_headers():
async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"content-type", b"text/plain; charset=utf-8"],
[b"x-custom-header", b"test"],
],
}
)
await send({"type": "http.response.body", "body": b"Hello world"})

event = get_mock_aws_api_gateway_event("GET", "/test", {}, None, False)

handler = Mangum(app, lifespan="off", exclude_headers=["X-CUSTOM-HEADER"])

response = handler(event, {})
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": b"text/plain; charset=utf-8".decode()},
"multiValueHeaders": {},
"body": "Hello world",
}
31 changes: 31 additions & 0 deletions tests/handlers/test_lambda_at_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,34 @@ async def app(scope, receive, send):
},
"body": utf_res_body,
}


def test_aws_lambda_at_edge_exclude_():
async def app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
[b"content-type", b"text/plain; charset=utf-8"],
[b"x-custom-header", b"test"],
],
}
)
await send({"type": "http.response.body", "body": b"Hello world"})

event = mock_lambda_at_edge_event("GET", "/test", {}, None, False)

handler = Mangum(app, lifespan="off", exclude_headers=["x-custom-header"])

response = handler(event, {})
assert response == {
"status": 200,
"isBase64Encoded": False,
"headers": {
"content-type": [
{"key": "content-type", "value": b"text/plain; charset=utf-8".decode()}
]
},
"body": "Hello world",
}
1 change: 1 addition & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_default_settings():
assert handler.lifespan == "auto"
assert handler.config["api_gateway_base_path"] == "/"
assert sorted(handler.config["text_mime_types"]) == sorted(DEFAULT_TEXT_MIME_TYPES)
assert handler.config["exclude_headers"] == []


@pytest.mark.parametrize(
Expand Down

0 comments on commit f1106dd

Please sign in to comment.