Skip to content

feat(event_handler): allow customers to catch request validation errors #3396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
Expand Down Expand Up @@ -1972,6 +1973,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
except ServiceError as service_error:
exp = service_error

if isinstance(exp, RequestValidationError):
return self._response_builder_class(
response=Response(
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
content_type=content_types.APPLICATION_JSON,
body={"statusCode": HTTPStatus.UNPROCESSABLE_ENTITY, "message": exp.errors()},
),
serializer=self._serializer,
route=route,
)

if isinstance(exp, ServiceError):
return self._response_builder_class(
response=Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
values: Dict[str, Any] = {}
errors: List[Any] = []

try:
# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
route.dependant.path_params,
app.context["_route_args"],
# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
route.dependant.path_params,
app.context["_route_args"],
)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors

# Process the request body, if it exists
if route.dependant.body_params:
(body_values, body_errors) = _request_body_to_args(
required_params=route.dependant.body_params,
received_body=self._get_body(app),
)
values.update(body_values)
errors.extend(body_errors)

# Process query values
query_values, query_errors = _request_params_to_args(
route.dependant.query_params,
app.current_event.query_string_parameters or {},
)

values.update(path_values)
values.update(query_values)
errors += path_errors + query_errors
if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values

# Process the request body, if it exists
if route.dependant.body_params:
(body_values, body_errors) = _request_body_to_args(
required_params=route.dependant.body_params,
received_body=self._get_body(app),
)
values.update(body_values)
errors.extend(body_errors)
# Call the handler by calling the next middleware
response = next_middleware(app)

if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values

# Call the handler by calling the next middleware
response = next_middleware(app)

# Process the response
return self._handle_response(route=route, response=response)
except RequestValidationError as e:
return Response(
status_code=422,
content_type="application/json",
body=json.dumps({"detail": e.errors()}),
)
# Process the response
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
Expand Down
46 changes: 46 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.cookies import Cookie
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand Down Expand Up @@ -1458,6 +1459,51 @@ def get_lambda() -> Response:
assert result["body"] == "Foo!"


def test_exception_handler_with_data_validation():
# GIVEN a resolver with an exception handler defined for RequestValidationError
app = ApiGatewayResolver(enable_validation=True)

@app.exception_handler(RequestValidationError)
def handle_validation_error(ex: RequestValidationError):
print(f"request path is '{app.current_event.path}'")
return Response(
status_code=422,
content_type=content_types.TEXT_PLAIN,
body=f"Invalid data. Number of errors: {len(ex.errors())}",
)

@app.get("/my/path")
def get_lambda(param: int):
...

# WHEN calling the event handler
# AND a RequestValidationError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 422
assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_PLAIN]
assert result["body"] == "Invalid data. Number of errors: 1"


def test_data_validation_error():
# GIVEN a resolver without an exception handler
app = ApiGatewayResolver(enable_validation=True)

@app.get("/my/path")
def get_lambda(param: int):
...

# WHEN calling the event handler
# AND a RequestValidationError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 422
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
assert "missing" in result["body"]


def test_exception_handler_service_error():
# GIVEN
app = ApiGatewayResolver()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class Model(BaseModel):
# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Model) -> Response[Model]:
return Response(body=user, status_code=200)
return Response(body=user, status_code=200, content_type="application/json")

LOAD_GW_EVENT["httpMethod"] = "POST"
LOAD_GW_EVENT["path"] = "/"
Expand All @@ -353,7 +353,7 @@ def handler(user: Model) -> Response[Model]:
# THEN the body must be a dict
result = app(LOAD_GW_EVENT, {})
assert result["statusCode"] == 200
assert result["body"] == {"name": "John", "age": 30}
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_validate_response_invalid_return():
Expand Down