Skip to content

Commit 8c859de

Browse files
author
Michael Brewer
authored
feat(apigateway): add exception_handler support (#898)
1 parent e91932c commit 8c859de

File tree

2 files changed

+123
-15
lines changed

2 files changed

+123
-15
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+49-14
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from enum import Enum
1111
from functools import partial
1212
from http import HTTPStatus
13-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
16-
from aws_lambda_powertools.event_handler.exceptions import ServiceError
16+
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
1717
from aws_lambda_powertools.shared import constants
1818
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
1919
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -27,7 +27,6 @@
2727
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
2828
# API GW/ALB decode non-safe URI chars; we must support them too
2929
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
30-
3130
_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
3231

3332

@@ -435,6 +434,7 @@ def __init__(
435434
self._proxy_type = proxy_type
436435
self._routes: List[Route] = []
437436
self._route_keys: List[str] = []
437+
self._exception_handlers: Dict[Type, Callable] = {}
438438
self._cors = cors
439439
self._cors_enabled: bool = cors is not None
440440
self._cors_methods: Set[str] = {"OPTIONS"}
@@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
596596
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
597597
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
598598

599+
handler = self._lookup_exception_handler(NotFoundError)
600+
if handler:
601+
return ResponseBuilder(handler(NotFoundError()))
602+
599603
return ResponseBuilder(
600604
Response(
601605
status_code=HTTPStatus.NOT_FOUND.value,
@@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
609613
"""Actually call the matching route with any provided keyword arguments."""
610614
try:
611615
return ResponseBuilder(self._to_response(route.func(**args)), route)
612-
except ServiceError as e:
613-
return ResponseBuilder(
614-
Response(
615-
status_code=e.status_code,
616-
content_type=content_types.APPLICATION_JSON,
617-
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
618-
),
619-
route,
620-
)
621-
except Exception:
616+
except Exception as exc:
617+
response_builder = self._call_exception_handler(exc, route)
618+
if response_builder:
619+
return response_builder
620+
622621
if self._debug:
623622
# If the user has turned on debug mode,
624623
# we'll let the original exception propagate so
@@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
628627
status_code=500,
629628
content_type=content_types.TEXT_PLAIN,
630629
body="".join(traceback.format_exc()),
631-
)
630+
),
631+
route,
632632
)
633+
633634
raise
634635

636+
def not_found(self, func: Callable):
637+
return self.exception_handler(NotFoundError)(func)
638+
639+
def exception_handler(self, exc_class: Type[Exception]):
640+
def register_exception_handler(func: Callable):
641+
self._exception_handlers[exc_class] = func
642+
643+
return register_exception_handler
644+
645+
def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]:
646+
# Use "Method Resolution Order" to allow for matching against a base class
647+
# of an exception
648+
for cls in exp_type.__mro__:
649+
if cls in self._exception_handlers:
650+
return self._exception_handlers[cls]
651+
return None
652+
653+
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
654+
handler = self._lookup_exception_handler(type(exp))
655+
if handler:
656+
return ResponseBuilder(handler(exp), route)
657+
658+
if isinstance(exp, ServiceError):
659+
return ResponseBuilder(
660+
Response(
661+
status_code=exp.status_code,
662+
content_type=content_types.APPLICATION_JSON,
663+
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
664+
),
665+
route,
666+
)
667+
668+
return None
669+
635670
def _to_response(self, result: Union[Dict, Response]) -> Response:
636671
"""Convert the route's result to a Response
637672

Diff for: tests/functional/event_handler/test_api_gateway.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def patch_func():
163163
def handler(event, context):
164164
return app.resolve(event, context)
165165

166-
# Also check check the route configurations
166+
# Also check the route configurations
167167
routes = app._routes
168168
assert len(routes) == 5
169169
for route in routes:
@@ -1076,3 +1076,76 @@ def foo():
10761076

10771077
assert result["statusCode"] == 200
10781078
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1079+
1080+
1081+
def test_exception_handler():
1082+
# GIVEN a resolver with an exception handler defined for ValueError
1083+
app = ApiGatewayResolver()
1084+
1085+
@app.exception_handler(ValueError)
1086+
def handle_value_error(ex: ValueError):
1087+
print(f"request path is '{app.current_event.path}'")
1088+
return Response(
1089+
status_code=418,
1090+
content_type=content_types.TEXT_HTML,
1091+
body=str(ex),
1092+
)
1093+
1094+
@app.get("/my/path")
1095+
def get_lambda() -> Response:
1096+
raise ValueError("Foo!")
1097+
1098+
# WHEN calling the event handler
1099+
# AND a ValueError is raised
1100+
result = app(LOAD_GW_EVENT, {})
1101+
1102+
# THEN call the exception_handler
1103+
assert result["statusCode"] == 418
1104+
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
1105+
assert result["body"] == "Foo!"
1106+
1107+
1108+
def test_exception_handler_service_error():
1109+
# GIVEN
1110+
app = ApiGatewayResolver()
1111+
1112+
@app.exception_handler(ServiceError)
1113+
def service_error(ex: ServiceError):
1114+
print(ex.msg)
1115+
return Response(
1116+
status_code=ex.status_code,
1117+
content_type=content_types.APPLICATION_JSON,
1118+
body="CUSTOM ERROR FORMAT",
1119+
)
1120+
1121+
@app.get("/my/path")
1122+
def get_lambda() -> Response:
1123+
raise InternalServerError("Something sensitive")
1124+
1125+
# WHEN calling the event handler
1126+
# AND a ServiceError is raised
1127+
result = app(LOAD_GW_EVENT, {})
1128+
1129+
# THEN call the exception_handler
1130+
assert result["statusCode"] == 500
1131+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1132+
assert result["body"] == "CUSTOM ERROR FORMAT"
1133+
1134+
1135+
def test_exception_handler_not_found():
1136+
# GIVEN a resolver with an exception handler defined for a 404 not found
1137+
app = ApiGatewayResolver()
1138+
1139+
@app.not_found
1140+
def handle_not_found(exc: NotFoundError) -> Response:
1141+
assert isinstance(exc, NotFoundError)
1142+
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")
1143+
1144+
# WHEN calling the event handler
1145+
# AND not route is found
1146+
result = app(LOAD_GW_EVENT, {})
1147+
1148+
# THEN call the exception_handler
1149+
assert result["statusCode"] == 404
1150+
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
1151+
assert result["body"] == "I am a teapot!"

0 commit comments

Comments
 (0)