Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): add support for multiValueQueryStringParameters in OpenAPI schema #3667

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_regenerate_error_with_loc,
get_missing_field_error,
)
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param
Expand Down Expand Up @@ -68,10 +69,16 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
app.context["_route_args"],
)

# Normalize query values before validate this
query_string = _normalize_multi_query_string_with_param(
app.current_event.resolved_query_string_parameters,
route.dependant.query_params,
)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
query_string,
)

values.update(path_values)
Expand Down Expand Up @@ -344,3 +351,29 @@ def _get_embed_body(
received_body = {field.alias: received_body}

return received_body, field_alias_omitted


def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]):
"""
Extract and normalize resolved_query_string_parameters

Parameters
----------
query_string: Dict
A dictionary containing the initial query string parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.

Returns
-------
A dictionary containing the processed multi_query_string_parameters.
"""
if query_string:
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
query_string[param.name] = query_string[param.name][0]
except KeyError:
pass
return query_string
9 changes: 8 additions & 1 deletion aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
Expand Down Expand Up @@ -35,6 +35,13 @@ def request_context(self) -> ALBEventRequestContext:
def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters

@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def multi_value_headers(self) -> Dict[str, List[str]]:
def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters

@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)
Expand Down Expand Up @@ -299,3 +306,13 @@ def http_method(self) -> str:

def header_serializer(self):
return HttpApiHeadersSerializer()

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
if self.query_string_parameters is not None:
query_string = {
key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items()
}
return query_string

return {}
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
# In Bedrock Agent events, query string parameters are passed as undifferentiated parameters,
# together with the other parameters. So we just return all parameters here.
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
11 changes: 11 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ def headers(self) -> Dict[str, str]:
def query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.get("queryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
"""
This property determines the appropriate query string parameter to be used
as a trusted source for validating OpenAPI.
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

This is necessary because different resolvers use different formats to encode
multi query string parameters.
"""
return self.query_string_parameters

@property
def is_base64_encoded(self) -> Optional[bool]:
return self.get("isBase64Encoded")
Expand Down
8 changes: 8 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def query_string_parameters(self) -> Dict[str, str]:
"""The request query string parameters."""
return self["query_string_parameters"]

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters


class vpcLatticeEventV2Identity(DictWrapper):
@property
Expand Down Expand Up @@ -251,3 +255,7 @@ def request_context(self) -> vpcLatticeEventV2RequestContext:
def query_string_parameters(self) -> Optional[Dict[str, str]]:
"""The request query string parameters."""
return self.get("queryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
10 changes: 10 additions & 0 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,16 @@ In the following example, we use a new `Query` OpenAPI type to add [one out of m

1. `completed` is still the same query string as before, except we simply state it's an string. No `Query` or `Annotated` to validate it.

=== "working_with_multi_query_values.py"

If you need to handle multi-value query parameters, you can create a list of the desired type.

```python hl_lines="23"
--8<-- "examples/event_handler_rest/src/working_with_multi_query_values.py"
```

1. `example_multi_value_param` is a list containing values from the `ExampleEnum` enumeration.

<!-- markdownlint-enable MD013 -->

#### Validating path parameters
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from typing import List

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.params import Query
from aws_lambda_powertools.shared.types import Annotated
from aws_lambda_powertools.utilities.typing import LambdaContext

app = APIGatewayRestResolver(enable_validation=True)


class ExampleEnum(Enum):
"""Example of an Enum class."""

ONE = "value_one"
TWO = "value_two"
THREE = "value_three"


@app.get("/todos")
def get(
example_multi_value_param: Annotated[
List[ExampleEnum], # (1)!
Query(
description="This is multi value query parameter.",
),
],
):
"""Return validated multi-value param values."""
return example_multi_value_param


def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
38 changes: 38 additions & 0 deletions tests/events/albMultiValueQueryStringEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"requestContext": {
"elb": {
"targetGroupArn": "arn:aws:elasticloadbalancing:eu-central-1:1234567890:targetgroup/alb-c-Targe-11GDXTPQ7663S/804a67588bfdc10f"
}
},
"httpMethod": "GET",
"path": "/todos",
"multiValueQueryStringParameters": {
"parameter1": ["value1","value2"],
"parameter2": ["value"]
},
"multiValueHeaders": {
"accept": [
"*/*"
],
"host": [
"alb-c-LoadB-14POFKYCLBNSF-1815800096.eu-central-1.elb.amazonaws.com"
],
"user-agent": [
"curl/7.79.1"
],
"x-amzn-trace-id": [
"Root=1-62fa9327-21cdd4da4c6db451490a5fb7"
],
"x-forwarded-for": [
"123.123.123.123"
],
"x-forwarded-port": [
"80"
],
"x-forwarded-proto": [
"http"
]
},
"body": "",
"isBase64Encoded": false
}
51 changes: 51 additions & 0 deletions tests/events/lambdaFunctionUrlEventWithHeaders.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"version":"2.0",
"routeKey":"$default",
"rawPath":"/",
"rawQueryString":"",
"headers":{
"sec-fetch-mode":"navigate",
"x-amzn-tls-version":"TLSv1.2",
"sec-fetch-site":"cross-site",
"accept-language":"pt-BR,pt;q=0.9",
"x-forwarded-proto":"https",
"x-forwarded-port":"443",
"x-forwarded-for":"123.123.123.123",
"sec-fetch-user":"?1",
"accept":"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"x-amzn-tls-cipher-suite":"ECDHE-RSA-AES128-GCM-SHA256",
"sec-ch-ua":"\" Not A;Brand\";v=\"99\", \"Chromium\";v=\"102\", \"Google Chrome\";v=\"102\"",
"sec-ch-ua-mobile":"?0",
"x-amzn-trace-id":"Root=1-62ecd163-5f302e550dcde3b12402207d",
"sec-ch-ua-platform":"\"Linux\"",
"host":"<url-id>.lambda-url.us-east-1.on.aws",
"upgrade-insecure-requests":"1",
"cache-control":"max-age=0",
"accept-encoding":"gzip, deflate, br",
"sec-fetch-dest":"document",
"user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36"
},
"queryStringParameters": {
"parameter1": "value1,value2",
"parameter2": "value"
},
"requestContext":{
"accountId":"anonymous",
"apiId":"<url-id>",
"domainName":"<url-id>.lambda-url.us-east-1.on.aws",
"domainPrefix":"<url-id>",
"http":{
"method":"GET",
"path":"/",
"protocol":"HTTP/1.1",
"sourceIp":"123.123.123.123",
"userAgent":"agent"
},
"requestId":"id",
"routeKey":"$default",
"stage":"$default",
"time":"05/Aug/2022:08:14:39 +0000",
"timeEpoch":1659687279885
},
"isBase64Encoded":false
}
36 changes: 36 additions & 0 deletions tests/events/vpcLatticeV2EventWithHeaders.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"version": "2.0",
"path": "/newpath",
"method": "GET",
"headers": {
"user_agent": "curl/7.64.1",
"x-forwarded-for": "10.213.229.10",
"host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws",
"accept": "*/*"
},
"queryStringParameters": {
"parameter1": [
"value1",
"value2"
],
"parameter2": [
"value"
]
},
"body": "{\"message\": \"Hello from Lambda!\"}",
"isBase64Encoded": false,
"requestContext": {
"serviceNetworkArn": "arn:aws:vpc-lattice:us-east-2:123456789012:servicenetwork/sn-0bf3f2882e9cc805a",
"serviceArn": "arn:aws:vpc-lattice:us-east-2:123456789012:service/svc-0a40eebed65f8d69c",
"targetGroupArn": "arn:aws:vpc-lattice:us-east-2:123456789012:targetgroup/tg-6d0ecf831eec9f09",
"identity": {
"sourceVpcArn": "arn:aws:ec2:region:123456789012:vpc/vpc-0b8276c84697e7339",
"type" : "AWS_IAM",
"principal": "arn:aws:sts::123456789012:assumed-role/example-role/057d00f8b51257ba3c853a0f248943cf",
"sessionName": "057d00f8b51257ba3c853a0f248943cf",
"x509SanDns": "example.com"
},
"region": "us-east-2",
"timeEpoch": "1696331543569073"
}
}
14 changes: 14 additions & 0 deletions tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ def handler(page: Annotated[str, Query(include_in_schema=False)]):
assert get.parameters is None


def test_openapi_with_list_param():
app = APIGatewayRestResolver()

@app.get("/")
def handler(page: Annotated[List[str], Query()]):
return page

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters[0].schema_.type == "array"


def test_openapi_with_description():
app = APIGatewayRestResolver()

Expand Down
Loading