diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 0edfe985cd7..2c829789e8c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -190,9 +192,12 @@ def __init__( allow_credentials: bool A boolean value that sets the value of `Access-Control-Allow-Credentials` """ + self._allowed_origins = [allow_origin] + if extra_origins: self._allowed_origins.extend(extra_origins) + self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) self.expose_headers = expose_headers or [] self.max_age = max_age @@ -220,10 +225,18 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers) if self.max_age is not None: headers["Access-Control-Max-Age"] = str(self.max_age) - if self.allow_credentials is True: + if origin != "*" and self.allow_credentials is True: headers["Access-Control-Allow-Credentials"] = "true" return headers + def allowed_origin(self, extracted_origin: str) -> str | None: + if extracted_origin in self._allowed_origins: + return extracted_origin + if extracted_origin is not None and "*" in self._allowed_origins: + return "*" + + return None + @staticmethod def build_allow_methods(methods: Set[str]) -> str: """Build sorted comma delimited methods for Access-Control-Allow-Methods header @@ -808,7 +821,10 @@ def __init__( def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" extracted_origin_header = extract_origin_header(event.resolved_headers_field) - self.response.headers.update(cors.to_dict(extracted_origin_header)) + + origin = cors.allowed_origin(extracted_origin_header) + if origin is not None: + self.response.headers.update(cors.to_dict(origin)) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" diff --git a/tests/events/apiGatewayProxyEventNoOrigin.json b/tests/events/apiGatewayProxyEventNoOrigin.json new file mode 100644 index 00000000000..666022723ad --- /dev/null +++ b/tests/events/apiGatewayProxyEventNoOrigin.json @@ -0,0 +1,80 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/functional/event_handler/required_dependencies/test_api_gateway.py b/tests/functional/event_handler/required_dependencies/test_api_gateway.py index ecd514aa0ee..ef36ab00587 100644 --- a/tests/functional/event_handler/required_dependencies/test_api_gateway.py +++ b/tests/functional/event_handler/required_dependencies/test_api_gateway.py @@ -48,6 +48,7 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +LOAD_GW_EVENT_NO_ORIGIN = load_event("apiGatewayProxyEventNoOrigin.json") LOAD_GW_EVENT_TRAILING_SLASH = load_event("apiGatewayProxyEventPathTrailingSlash.json") @@ -324,7 +325,7 @@ def handler(event, context): def test_cors(): # GIVEN a function with cors=True # AND http method set to GET - app = ApiGatewayResolver() + app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True)) @app.get("/my/path", cors=True) def with_cors() -> Response: @@ -345,6 +346,69 @@ def handler(event, context): headers = result["multiValueHeaders"] assert headers["Content-Type"] == [content_types.TEXT_HTML] assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"] + assert "Access-Control-Allow-Credentials" in headers + assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))] + + # THEN for routes without cors flag return no cors headers + mock_event = {"path": "/my/request", "httpMethod": "GET"} + result = handler(mock_event, None) + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] + + +def test_cors_no_request_origin(): + # GIVEN a function with cors=True + # AND http method set to GET + app = ApiGatewayResolver() + + @app.get("/my/path", cors=True) + def with_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + def handler(event, context): + return app.resolve(event, context) + + event = LOAD_GW_EVENT_NO_ORIGIN + + # WHEN calling the event handler + result = handler(event, None) + + # THEN the headers should include cors headers + assert "multiValueHeaders" in result + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.TEXT_HTML] + assert "Access-Control-Allow-Credentials" not in headers + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] + + +def test_cors_allow_all_request_origins(): + # GIVEN a function with cors=True + # AND http method set to GET + app = ApiGatewayResolver( + cors=CORSConfig( + allow_origin="*", + allow_credentials=True, + ), + ) + + @app.get("/my/path", cors=True) + def with_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + @app.get("/without-cors") + def without_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + result = handler(LOAD_GW_EVENT, None) + + # THEN the headers should include cors headers + assert "multiValueHeaders" in result + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.TEXT_HTML] + assert headers["Access-Control-Allow-Origin"] == ["*"] assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]