10
10
from enum import Enum
11
11
from functools import partial
12
12
from http import HTTPStatus
13
- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
13
+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
14
14
15
15
from aws_lambda_powertools .event_handler import content_types
16
- from aws_lambda_powertools .event_handler .exceptions import ServiceError
16
+ from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
17
17
from aws_lambda_powertools .shared import constants
18
18
from aws_lambda_powertools .shared .functions import resolve_truthy_env_var_choice
19
19
from aws_lambda_powertools .shared .json_encoder import Encoder
27
27
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
28
28
# API GW/ALB decode non-safe URI chars; we must support them too
29
29
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
30
-
31
30
_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{ _SAFE_URI } { _UNSAFE_URI } \\w]+)"
32
31
33
32
@@ -435,6 +434,7 @@ def __init__(
435
434
self ._proxy_type = proxy_type
436
435
self ._routes : List [Route ] = []
437
436
self ._route_keys : List [str ] = []
437
+ self ._exception_handlers : Dict [Type , Callable ] = {}
438
438
self ._cors = cors
439
439
self ._cors_enabled : bool = cors is not None
440
440
self ._cors_methods : Set [str ] = {"OPTIONS" }
@@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
596
596
headers ["Access-Control-Allow-Methods" ] = "," .join (sorted (self ._cors_methods ))
597
597
return ResponseBuilder (Response (status_code = 204 , content_type = None , headers = headers , body = None ))
598
598
599
+ handler = self ._lookup_exception_handler (NotFoundError )
600
+ if handler :
601
+ return ResponseBuilder (handler (NotFoundError ()))
602
+
599
603
return ResponseBuilder (
600
604
Response (
601
605
status_code = HTTPStatus .NOT_FOUND .value ,
@@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
609
613
"""Actually call the matching route with any provided keyword arguments."""
610
614
try :
611
615
return ResponseBuilder (self ._to_response (route .func (** args )), route )
612
- except ServiceError as e :
613
- return ResponseBuilder (
614
- Response (
615
- status_code = e .status_code ,
616
- content_type = content_types .APPLICATION_JSON ,
617
- body = self ._json_dump ({"statusCode" : e .status_code , "message" : e .msg }),
618
- ),
619
- route ,
620
- )
621
- except Exception :
616
+ except Exception as exc :
617
+ response_builder = self ._call_exception_handler (exc , route )
618
+ if response_builder :
619
+ return response_builder
620
+
622
621
if self ._debug :
623
622
# If the user has turned on debug mode,
624
623
# we'll let the original exception propagate so
@@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
628
627
status_code = 500 ,
629
628
content_type = content_types .TEXT_PLAIN ,
630
629
body = "" .join (traceback .format_exc ()),
631
- )
630
+ ),
631
+ route ,
632
632
)
633
+
633
634
raise
634
635
636
+ def not_found (self , func : Callable ):
637
+ return self .exception_handler (NotFoundError )(func )
638
+
639
+ def exception_handler (self , exc_class : Type [Exception ]):
640
+ def register_exception_handler (func : Callable ):
641
+ self ._exception_handlers [exc_class ] = func
642
+
643
+ return register_exception_handler
644
+
645
+ def _lookup_exception_handler (self , exp_type : Type ) -> Optional [Callable ]:
646
+ # Use "Method Resolution Order" to allow for matching against a base class
647
+ # of an exception
648
+ for cls in exp_type .__mro__ :
649
+ if cls in self ._exception_handlers :
650
+ return self ._exception_handlers [cls ]
651
+ return None
652
+
653
+ def _call_exception_handler (self , exp : Exception , route : Route ) -> Optional [ResponseBuilder ]:
654
+ handler = self ._lookup_exception_handler (type (exp ))
655
+ if handler :
656
+ return ResponseBuilder (handler (exp ), route )
657
+
658
+ if isinstance (exp , ServiceError ):
659
+ return ResponseBuilder (
660
+ Response (
661
+ status_code = exp .status_code ,
662
+ content_type = content_types .APPLICATION_JSON ,
663
+ body = self ._json_dump ({"statusCode" : exp .status_code , "message" : exp .msg }),
664
+ ),
665
+ route ,
666
+ )
667
+
668
+ return None
669
+
635
670
def _to_response (self , result : Union [Dict , Response ]) -> Response :
636
671
"""Convert the route's result to a Response
637
672
0 commit comments