Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(event_handler): disable allow-credentials header when origin allow_origin is * #4638

Merged
merged 9 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json
import logging
Expand Down Expand Up @@ -190,9 +192,12 @@ def __init__(
allow_credentials: bool
A boolean value that sets the value of `Access-Control-Allow-Credentials`
"""

self._allowed_origins = [allow_origin]

if extra_origins:
self._allowed_origins.extend(extra_origins)

self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
self.expose_headers = expose_headers or []
self.max_age = max_age
Expand Down Expand Up @@ -220,10 +225,18 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
if self.max_age is not None:
headers["Access-Control-Max-Age"] = str(self.max_age)
if self.allow_credentials is True:
if origin != "*" and self.allow_credentials is True:
sthulb marked this conversation as resolved.
Show resolved Hide resolved
headers["Access-Control-Allow-Credentials"] = "true"
return headers

def allowed_origin(self, extracted_origin: str) -> str | None:
if extracted_origin in self._allowed_origins:
return extracted_origin
if extracted_origin is not None and "*" in self._allowed_origins:
return "*"

return None

@staticmethod
def build_allow_methods(methods: Set[str]) -> str:
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
Expand Down Expand Up @@ -808,7 +821,10 @@ def __init__(
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
extracted_origin_header = extract_origin_header(event.resolved_headers_field)
self.response.headers.update(cors.to_dict(extracted_origin_header))

origin = cors.allowed_origin(extracted_origin_header)
if origin is not None:
self.response.headers.update(cors.to_dict(origin))

def _add_cache_control(self, cache_control: str):
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
Expand Down
80 changes: 80 additions & 0 deletions tests/events/apiGatewayProxyEventNoOrigin.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"version": "1.0",
"resource": "/my/path",
"path": "/my/path",
"httpMethod": "GET",
"headers": {
"Header1": "value1",
"Header2": "value2"
},
"multiValueHeaders": {
"Header1": [
"value1"
],
"Header2": [
"value1",
"value2"
]
},
"queryStringParameters": {
"parameter1": "value1",
"parameter2": "value"
},
"multiValueQueryStringParameters": {
"parameter1": [
"value1",
"value2"
],
"parameter2": [
"value"
]
},
"requestContext": {
"accountId": "123456789012",
"apiId": "id",
"authorizer": {
"claims": null,
"scopes": null
},
"domainName": "id.execute-api.us-east-1.amazonaws.com",
"domainPrefix": "id",
"extendedRequestId": "request-id",
"httpMethod": "GET",
"identity": {
"accessKey": null,
"accountId": null,
"caller": null,
"cognitoAuthenticationProvider": null,
"cognitoAuthenticationType": null,
"cognitoIdentityId": null,
"cognitoIdentityPoolId": null,
"principalOrgId": null,
"sourceIp": "192.168.0.1/32",
"user": null,
"userAgent": "user-agent",
"userArn": null,
"clientCert": {
"clientCertPem": "CERT_CONTENT",
"subjectDN": "www.example.com",
"issuerDN": "Example issuer",
"serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1",
"validity": {
"notBefore": "May 28 12:30:02 2019 GMT",
"notAfter": "Aug 5 09:36:04 2021 GMT"
}
}
},
"path": "/my/path",
"protocol": "HTTP/1.1",
"requestId": "id=",
"requestTime": "04/Mar/2020:19:15:17 +0000",
"requestTimeEpoch": 1583349317135,
"resourceId": null,
"resourcePath": "/my/path",
"stage": "$default"
},
"pathParameters": null,
"stageVariables": null,
"body": "Hello from Lambda!",
"isBase64Encoded": false
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def read_media(file_name: str) -> bytes:


LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
LOAD_GW_EVENT_NO_ORIGIN = load_event("apiGatewayProxyEventNoOrigin.json")
LOAD_GW_EVENT_TRAILING_SLASH = load_event("apiGatewayProxyEventPathTrailingSlash.json")


Expand Down Expand Up @@ -324,7 +325,7 @@ def handler(event, context):
def test_cors():
# GIVEN a function with cors=True
# AND http method set to GET
app = ApiGatewayResolver()
app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True))

@app.get("/my/path", cors=True)
def with_cors() -> Response:
Expand All @@ -345,6 +346,69 @@ def handler(event, context):
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.TEXT_HTML]
assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"]
assert "Access-Control-Allow-Credentials" in headers
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]

# THEN for routes without cors flag return no cors headers
mock_event = {"path": "/my/request", "httpMethod": "GET"}
result = handler(mock_event, None)
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]


def test_cors_no_request_origin():
# GIVEN a function with cors=True
# AND http method set to GET
app = ApiGatewayResolver()

@app.get("/my/path", cors=True)
def with_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

def handler(event, context):
return app.resolve(event, context)

event = LOAD_GW_EVENT_NO_ORIGIN

# WHEN calling the event handler
result = handler(event, None)

# THEN the headers should include cors headers
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.TEXT_HTML]
assert "Access-Control-Allow-Credentials" not in headers
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]


def test_cors_allow_all_request_origins():
# GIVEN a function with cors=True
# AND http method set to GET
app = ApiGatewayResolver(
cors=CORSConfig(
allow_origin="*",
allow_credentials=True,
),
)

@app.get("/my/path", cors=True)
def with_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

@app.get("/without-cors")
def without_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

def handler(event, context):
return app.resolve(event, context)

# WHEN calling the event handler
result = handler(LOAD_GW_EVENT, None)

# THEN the headers should include cors headers
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.TEXT_HTML]
assert headers["Access-Control-Allow-Origin"] == ["*"]
assert "Access-Control-Allow-Credentials" not in headers
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]

Expand Down