Skip to content

Commit

Permalink
feat(event_source): allow multiple CORS origins (#2279)
Browse files Browse the repository at this point in the history
Co-authored-by: Leandro Damascena <leandro.damascena@gmail.com>
  • Loading branch information
rubenfonseca and leandrodamascena authored May 18, 2023
1 parent 27d197c commit 042e83a
Show file tree
Hide file tree
Showing 13 changed files with 414 additions and 27 deletions.
31 changes: 24 additions & 7 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def with_cors():
cors_config = CORSConfig(
allow_origin="https://wwww.example.com/",
extra_origins=["https://dev.example.com/"],
expose_headers=["x-exposed-response-header"],
allow_headers=["x-custom-request-header"],
max_age=100,
Expand All @@ -106,6 +107,7 @@ def without_cors():
def __init__(
self,
allow_origin: str = "*",
extra_origins: Optional[List[str]] = None,
allow_headers: Optional[List[str]] = None,
expose_headers: Optional[List[str]] = None,
max_age: Optional[int] = None,
Expand All @@ -117,6 +119,8 @@ def __init__(
allow_origin: str
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
only be used during development.
extra_origins: Optional[List[str]]
The list of additional allowed origins.
allow_headers: Optional[List[str]]
The list of additional allowed headers. This list is added to list of
built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
Expand All @@ -128,16 +132,29 @@ def __init__(
allow_credentials: bool
A boolean value that sets the value of `Access-Control-Allow-Credentials`
"""
self.allow_origin = allow_origin
self._allowed_origins = [allow_origin]
if extra_origins:
self._allowed_origins.extend(extra_origins)
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
self.expose_headers = expose_headers or []
self.max_age = max_age
self.allow_credentials = allow_credentials

def to_dict(self) -> Dict[str, str]:
def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
"""Builds the configured Access-Control http headers"""

# If there's no Origin, don't add any CORS headers
if not origin:
return {}

# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
# don't add any CORS headers
if origin not in self._allowed_origins and "*" not in self._allowed_origins:
return {}

# The origin matched an allowed origin, so return the CORS headers
headers: Dict[str, str] = {
"Access-Control-Allow-Origin": self.allow_origin,
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
}

Expand Down Expand Up @@ -207,9 +224,9 @@ def __init__(self, response: Response, route: Optional[Route] = None):
self.response = response
self.route = route

def _add_cors(self, cors: CORSConfig):
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
self.response.headers.update(cors.to_dict())
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))

def _add_cache_control(self, cache_control: str):
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
Expand All @@ -230,7 +247,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
if self.route is None:
return
if self.route.cors:
self._add_cors(cors or CORSConfig())
self._add_cors(event, cors or CORSConfig())
if self.route.cache_control:
self._add_cache_control(self.route.cache_control)
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
Expand Down Expand Up @@ -644,7 +661,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers: Dict[str, Union[str, List[str]]] = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())
headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin")))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_header_value(
class BaseProxyEvent(DictWrapper):
@property
def headers(self) -> Dict[str, str]:
return self["headers"]
return self.get("headers") or {}

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
Expand Down
23 changes: 20 additions & 3 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ To address this API Gateway behavior, we use `strip_prefixes` parameter to accou

You can configure CORS at the `APIGatewayRestResolver` constructor via `cors` parameter using the `CORSConfig` class.

This will ensure that CORS headers are always returned as part of the response when your functions match the path invoked.
This will ensure that CORS headers are returned as part of the response when your functions match the path invoked and the `Origin`
matches one of the allowed values.

???+ tip
Optionally disable CORS on a per path basis with `cors=False` parameter.
Expand All @@ -297,6 +298,18 @@ This will ensure that CORS headers are always returned as part of the response w
--8<-- "examples/event_handler_rest/src/setting_cors_output.json"
```

=== "setting_cors_extra_origins.py"

```python hl_lines="5 11-12 34"
--8<-- "examples/event_handler_rest/src/setting_cors_extra_origins.py"
```

=== "setting_cors_extra_origins_output.json"

```json
--8<-- "examples/event_handler_rest/src/setting_cors_extra_origins_output.json"
```

#### Pre-flight

Pre-flight (OPTIONS) calls are typically handled at the API Gateway or Lambda Function URL level as per [our sample infrastructure](#required-resources), no Lambda integration is necessary. However, ALB expects you to handle pre-flight requests.
Expand All @@ -310,9 +323,13 @@ For convenience, these are the default values when using `CORSConfig` to enable
???+ warning
Always configure `allow_origin` when using in production.

???+ tip "Multiple origins?"
If you need to allow multiple origins, pass the additional origins using the `extra_origins` key.

| Key | Value | Note |
| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|----------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it |
| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` |
| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience |
| **[expose_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers){target="_blank"}**: `List[str]` | `[]` | Any additional header beyond the [safe listed by CORS specification](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header){target="_blank"}. |
| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway |
Expand All @@ -331,7 +348,7 @@ You can use the `Response` class to have full control over the response. For exa

=== "fine_grained_responses.py"

```python hl_lines="9 28-32"
```python hl_lines="9 29-35"
--8<-- "examples/event_handler_rest/src/fine_grained_responses.py"
```

Expand Down
3 changes: 2 additions & 1 deletion examples/event_handler_rest/src/setting_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

tracer = Tracer()
logger = Logger()
cors_config = CORSConfig(allow_origin="https://example.com", max_age=300)
# CORS will match when Origin is only https://www.example.com
cors_config = CORSConfig(allow_origin="https://www.example.com", max_age=300)
app = APIGatewayRestResolver(cors=cors_config)


Expand Down
45 changes: 45 additions & 0 deletions examples/event_handler_rest/src/setting_cors_extra_origins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import requests
from requests import Response

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
# CORS will match when Origin is https://www.example.com OR https://dev.example.com
cors_config = CORSConfig(allow_origin="https://www.example.com", extra_origins=["https://dev.example.com"], max_age=300)
app = APIGatewayRestResolver(cors=cors_config)


@app.get("/todos")
@tracer.capture_method
def get_todos():
todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos")
todos.raise_for_status()

# for brevity, we'll limit to the first 10 only
return {"todos": todos.json()[:10]}


@app.get("/todos/<todo_id>")
@tracer.capture_method
def get_todo_by_id(todo_id: str): # value come as str
todos: Response = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
todos.raise_for_status()

return {"todos": todos.json()}


@app.get("/healthcheck", cors=False) # optionally removes CORS for a given route
@tracer.capture_method
def am_i_alive():
return {"am_i_alive": "yes"}


# You can continue to use other utilities just as before
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"statusCode": 200,
"multiValueHeaders": {
"Content-Type": ["application/json"],
"Access-Control-Allow-Origin": ["https://www.example.com","https://dev.example.com"],
"Access-Control-Allow-Headers": ["Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key"]
},
"body": "{\"todos\":[{\"userId\":1,\"id\":1,\"title\":\"delectus aut autem\",\"completed\":false},{\"userId\":1,\"id\":2,\"title\":\"quis ut nam facilis et officia qui\",\"completed\":false},{\"userId\":1,\"id\":3,\"title\":\"fugiat veniam minus\",\"completed\":false},{\"userId\":1,\"id\":4,\"title\":\"et porro tempora\",\"completed\":true},{\"userId\":1,\"id\":5,\"title\":\"laboriosam mollitia et enim quasi adipisci quia provident illum\",\"completed\":false},{\"userId\":1,\"id\":6,\"title\":\"qui ullam ratione quibusdam voluptatem quia omnis\",\"completed\":false},{\"userId\":1,\"id\":7,\"title\":\"illo expedita consequatur quia in\",\"completed\":false},{\"userId\":1,\"id\":8,\"title\":\"quo adipisci enim quam ut ab\",\"completed\":true},{\"userId\":1,\"id\":9,\"title\":\"molestiae perspiciatis ipsa\",\"completed\":false},{\"userId\":1,\"id\":10,\"title\":\"illo est ratione doloremque quia maiores aut\",\"completed\":true}]}",
"isBase64Encoded": false
}
12 changes: 9 additions & 3 deletions tests/e2e/event_handler/handlers/alb_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from aws_lambda_powertools.event_handler import ALBResolver, Response, content_types

app = ALBResolver()
from aws_lambda_powertools.event_handler import (
ALBResolver,
CORSConfig,
Response,
content_types,
)

cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = ALBResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/event_handler/handlers/api_gateway_http_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
APIGatewayHttpResolver,
CORSConfig,
Response,
content_types,
)

app = APIGatewayHttpResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = APIGatewayHttpResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/event_handler/handlers/api_gateway_rest_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
APIGatewayRestResolver,
CORSConfig,
Response,
content_types,
)

app = APIGatewayRestResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = APIGatewayRestResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
CORSConfig,
LambdaFunctionUrlResolver,
Response,
content_types,
)

app = LambdaFunctionUrlResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = LambdaFunctionUrlResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Loading

0 comments on commit 042e83a

Please sign in to comment.