From 4619fe8d0438d51b63bd65e7126f2e0860bb21d1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:05:43 +0100 Subject: [PATCH 1/8] chore(event_handler): only apply serialization at the end --- aws_lambda_powertools/event_handler/api_gateway.py | 2 ++ tests/functional/event_handler/test_api_gateway.py | 12 ++++++------ tests/functional/event_handler/test_base_path.py | 12 ++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5b7262e5d55..fd54eb07b3a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -789,6 +789,8 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic logger.debug("Encoding bytes response with base64") self.response.base64_encoded = True self.response.body = base64.b64encode(self.response.body).decode() + elif self.response.is_json(): + self.response.body = self.serializer(self.response.body) # We only apply the serializer when the content type is JSON and the # body is not a str, to avoid double encoding diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 570de9ec808..4ef5fa0896f 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -367,7 +367,7 @@ def test_override_route_compress_parameter(): # AND the Response object with compress=False app = ApiGatewayResolver() mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} - expected_value = '{"test": "value"}' + expected_value = {"test": "value"} @app.get("/my/request", compress=True) def with_compression() -> Response: @@ -381,7 +381,7 @@ def handler(event, context): # THEN the response is not compressed assert result["isBase64Encoded"] is False - assert result["body"] == expected_value + assert json.loads(result["body"]) == expected_value assert result["multiValueHeaders"].get("Content-Encoding") is None @@ -681,7 +681,7 @@ def another_one(): def test_no_content_response(): # GIVEN a response with no content-type or body response = Response(status_code=204, content_type=None, body=None, headers=None) - response_builder = ResponseBuilder(response) + response_builder = ResponseBuilder(response, serializer=json.dumps) # WHEN calling to_dict result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT)) @@ -1482,7 +1482,7 @@ def get_lambda() -> Response: # THEN call the exception_handler assert result["statusCode"] == 500 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] - assert result["body"] == "CUSTOM ERROR FORMAT" + assert result["body"] == '"CUSTOM ERROR FORMAT"' def test_exception_handler_not_found(): @@ -1778,11 +1778,11 @@ def test_route_match_prioritize_full_match(): @router.get("/my/{path}") def dynamic_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "dynamic"})) + return Response(200, content_types.APPLICATION_JSON, {"hello": "dynamic"}) @router.get("/my/path") def static_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "static"})) + return Response(200, content_types.APPLICATION_JSON, {"hello": "static"}) app.include_router(router) diff --git a/tests/functional/event_handler/test_base_path.py b/tests/functional/event_handler/test_base_path.py index 479a46bda55..adf3c5849df 100644 --- a/tests/functional/event_handler/test_base_path.py +++ b/tests/functional/event_handler/test_base_path.py @@ -21,7 +21,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_api_gateway_http(): @@ -38,7 +38,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_alb(): @@ -53,7 +53,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_base_path_lambda_function_url(): @@ -70,7 +70,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_vpc_lattice(): @@ -85,7 +85,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' def test_vpc_latticev2(): @@ -100,4 +100,4 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == "" + assert result["body"] == '""' From 0c264170f4e5b42a098cebeb1769ee0101aee15f Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 14:01:51 +0100 Subject: [PATCH 2/8] fix: avoid double encoding --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ++++- tests/functional/event_handler/test_api_gateway.py | 2 +- tests/functional/event_handler/test_base_path.py | 12 ++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index fd54eb07b3a..ef4b2be5860 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -789,7 +789,10 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic logger.debug("Encoding bytes response with base64") self.response.base64_encoded = True self.response.body = base64.b64encode(self.response.body).decode() - elif self.response.is_json(): + + # We only apply the serializer when the content type is JSON and the + # body is not a str, to avoid double encoding + elif self.response.is_json() and not isinstance(self.response.body, str): self.response.body = self.serializer(self.response.body) # We only apply the serializer when the content type is JSON and the diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 4ef5fa0896f..3cb1261eccd 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1482,7 +1482,7 @@ def get_lambda() -> Response: # THEN call the exception_handler assert result["statusCode"] == 500 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] - assert result["body"] == '"CUSTOM ERROR FORMAT"' + assert result["body"] == "CUSTOM ERROR FORMAT" def test_exception_handler_not_found(): diff --git a/tests/functional/event_handler/test_base_path.py b/tests/functional/event_handler/test_base_path.py index adf3c5849df..479a46bda55 100644 --- a/tests/functional/event_handler/test_base_path.py +++ b/tests/functional/event_handler/test_base_path.py @@ -21,7 +21,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_api_gateway_http(): @@ -38,7 +38,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_alb(): @@ -53,7 +53,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_base_path_lambda_function_url(): @@ -70,7 +70,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_vpc_lattice(): @@ -85,7 +85,7 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" def test_vpc_latticev2(): @@ -100,4 +100,4 @@ def handle(): result = app(event, {}) assert result["statusCode"] == 200 - assert result["body"] == '""' + assert result["body"] == "" From 2ebf19c7e4f780b3cdb070b58de9af277eec8371 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 14:06:12 +0100 Subject: [PATCH 3/8] fix: rolled back test changes --- tests/functional/event_handler/test_api_gateway.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 3cb1261eccd..e370ca4b99d 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -367,7 +367,7 @@ def test_override_route_compress_parameter(): # AND the Response object with compress=False app = ApiGatewayResolver() mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} - expected_value = {"test": "value"} + expected_value = '{"test": "value"}' @app.get("/my/request", compress=True) def with_compression() -> Response: @@ -381,7 +381,7 @@ def handler(event, context): # THEN the response is not compressed assert result["isBase64Encoded"] is False - assert json.loads(result["body"]) == expected_value + assert result["body"] == expected_value assert result["multiValueHeaders"].get("Content-Encoding") is None @@ -1778,11 +1778,11 @@ def test_route_match_prioritize_full_match(): @router.get("/my/{path}") def dynamic_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, {"hello": "dynamic"}) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "dynamic"})) @router.get("/my/path") def static_handler() -> Response: - return Response(200, content_types.APPLICATION_JSON, {"hello": "static"}) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"hello": "static"})) app.include_router(router) From dbebe2ac49aa2dbc5c3824d2eb6fd87121cca7fb Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:11:28 +0100 Subject: [PATCH 4/8] fix: remove code from bad rebase --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ef4b2be5860..5b7262e5d55 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -795,11 +795,6 @@ def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dic elif self.response.is_json() and not isinstance(self.response.body, str): self.response.body = self.serializer(self.response.body) - # We only apply the serializer when the content type is JSON and the - # body is not a str, to avoid double encoding - elif self.response.is_json() and not isinstance(self.response.body, str): - self.response.body = self.serializer(self.response.body) - return { "statusCode": self.response.status_code, "body": self.response.body, From c23312ee20bdfd349dc4ae8864c019dc201eeb39 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:13:22 +0100 Subject: [PATCH 5/8] fix: remove unused code --- tests/functional/event_handler/test_api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index e370ca4b99d..570de9ec808 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -681,7 +681,7 @@ def another_one(): def test_no_content_response(): # GIVEN a response with no content-type or body response = Response(status_code=204, content_type=None, body=None, headers=None) - response_builder = ResponseBuilder(response, serializer=json.dumps) + response_builder = ResponseBuilder(response) # WHEN calling to_dict result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT)) From 3b5132e317832bb9dbb9b98ed0e3c0e772345ab0 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 15:56:29 +0100 Subject: [PATCH 6/8] chore(event_handler): enable exception handler when using validation --- .../event_handler/api_gateway.py | 12 +++ .../middlewares/openapi_validation.py | 73 +++++++++---------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 5b7262e5d55..05831a2eea5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -32,6 +32,7 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, @@ -1972,6 +1973,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp except ServiceError as service_error: exp = service_error + if isinstance(exp, RequestValidationError): + return self._response_builder_class( + response=Response( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + content_type=content_types.APPLICATION_JSON, + body={"statusCode": HTTPStatus.UNPROCESSABLE_ENTITY, "message": exp.errors()}, + ), + serializer=self._serializer, + route=route, + ) + if isinstance(exp, ServiceError): return self._response_builder_class( response=Response( diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 131f9d267a3..34011b64384 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -62,50 +62,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> values: Dict[str, Any] = {} errors: List[Any] = [] - try: - # Process path values, which can be found on the route_args - path_values, path_errors = _request_params_to_args( - route.dependant.path_params, - app.context["_route_args"], + # Process path values, which can be found on the route_args + path_values, path_errors = _request_params_to_args( + route.dependant.path_params, + app.context["_route_args"], + ) + + # Process query values + query_values, query_errors = _request_params_to_args( + route.dependant.query_params, + app.current_event.query_string_parameters or {}, + ) + + values.update(path_values) + values.update(query_values) + errors += path_errors + query_errors + + # Process the request body, if it exists + if route.dependant.body_params: + (body_values, body_errors) = _request_body_to_args( + required_params=route.dependant.body_params, + received_body=self._get_body(app), ) + values.update(body_values) + errors.extend(body_errors) - # Process query values - query_values, query_errors = _request_params_to_args( - route.dependant.query_params, - app.current_event.query_string_parameters or {}, - ) - - values.update(path_values) - values.update(query_values) - errors += path_errors + query_errors + if errors: + # Raise the validation errors + raise RequestValidationError(_normalize_errors(errors)) + else: + # Re-write the route_args with the validated values, and call the next middleware + app.context["_route_args"] = values - # Process the request body, if it exists - if route.dependant.body_params: - (body_values, body_errors) = _request_body_to_args( - required_params=route.dependant.body_params, - received_body=self._get_body(app), - ) - values.update(body_values) - errors.extend(body_errors) + # Call the handler by calling the next middleware + response = next_middleware(app) - if errors: - # Raise the validation errors - raise RequestValidationError(_normalize_errors(errors)) - else: - # Re-write the route_args with the validated values, and call the next middleware - app.context["_route_args"] = values - - # Call the handler by calling the next middleware - response = next_middleware(app) - - # Process the response - return self._handle_response(route=route, response=response) - except RequestValidationError as e: - return Response( - status_code=422, - content_type="application/json", - body=json.dumps({"detail": e.errors()}), - ) + # Process the response + return self._handle_response(route=route, response=response) def _handle_response(self, *, route: Route, response: Response): # Process the response body if it exists From 4190dd14460492308df6b0ec7d92858f28897839 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 16:07:52 +0100 Subject: [PATCH 7/8] chore: add test --- .../event_handler/test_api_gateway.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 570de9ec808..d4c88b541aa 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -30,6 +30,7 @@ ServiceError, UnauthorizedError, ) +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.json_encoder import Encoder @@ -1458,6 +1459,51 @@ def get_lambda() -> Response: assert result["body"] == "Foo!" +def test_exception_handler_with_data_validation(): + # GIVEN a resolver with an exception handler defined for RequestValidationError + app = ApiGatewayResolver(enable_validation=True) + + @app.exception_handler(RequestValidationError) + def handle_validation_error(ex: RequestValidationError): + print(f"request path is '{app.current_event.path}'") + return Response( + status_code=422, + content_type=content_types.TEXT_PLAIN, + body=f"Invalid data. Number of errors: {len(ex.errors())}", + ) + + @app.get("/my/path") + def get_lambda(param: int): + ... + + # WHEN calling the event handler + # AND a RequestValidationError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 422 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_PLAIN] + assert result["body"] == "Invalid data. Number of errors: 1" + + +def test_data_validation_error(): + # GIVEN a resolver without an exception handler + app = ApiGatewayResolver(enable_validation=True) + + @app.get("/my/path") + def get_lambda(param: int): + ... + + # WHEN calling the event handler + # AND a RequestValidationError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 422 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + assert "missing" in result["body"] + + def test_exception_handler_service_error(): # GIVEN app = ApiGatewayResolver() From 6ca9fb1766467af00b36e67049d6c3b873a90370 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 16:20:49 +0100 Subject: [PATCH 8/8] fix: failing test --- .../event_handler/test_openapi_validation_middleware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 9c7ca371d54..f558bd23ced 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -343,7 +343,7 @@ class Model(BaseModel): # WHEN a handler is defined with a body parameter @app.post("/") def handler(user: Model) -> Response[Model]: - return Response(body=user, status_code=200) + return Response(body=user, status_code=200, content_type="application/json") LOAD_GW_EVENT["httpMethod"] = "POST" LOAD_GW_EVENT["path"] = "/" @@ -353,7 +353,7 @@ def handler(user: Model) -> Response[Model]: # THEN the body must be a dict result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["body"] == {"name": "John", "age": 30} + assert json.loads(result["body"]) == {"name": "John", "age": 30} def test_validate_response_invalid_return():