Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api-gateway): add common HTTP service errors #506

Merged
merged 8 commits into from
Jul 6, 2021
89 changes: 86 additions & 3 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,79 @@
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)
APPLICATION_JSON = "application/json"
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved


class ServiceError(Exception):
"""Service Error"""

def __init__(self, status_code: int, message: str):
"""
Parameters
----------
code: int
Http status code
message: str
Error message
"""
self.status_code = status_code
self.message = message

def __str__(self) -> str:
"""To string of the message only"""
return self.message


class BadRequestError(ServiceError):
"""Bad Request Error"""

def __init__(self, message: str):
"""
Parameters
----------
message: str
Error message
"""
super().__init__(400, message)


class UnauthorizedError(ServiceError):
"""Unauthorized Error"""

def __init__(self, message: str):
"""
Parameters
----------
message: str
Error message
"""
super().__init__(401, message)


class NotFoundError(ServiceError):
"""Not Found Error"""

def __init__(self, message: str = "Not found"):
"""
Parameters
----------
message: str
Error message
"""
super().__init__(404, message)


class InternalServerError(ServiceError):
"""Internal Serve Error"""

def __init__(self, message: str):
"""
Parameters
----------
message: str
Error message
"""
super().__init__(500, message)


class ProxyEventType(Enum):
Expand Down Expand Up @@ -467,15 +540,25 @@ def _not_found(self, method: str) -> ResponseBuilder:
return ResponseBuilder(
Response(
status_code=404,
content_type="application/json",
content_type=APPLICATION_JSON,
headers=headers,
body=json.dumps({"message": "Not found"}),
)
)

def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
return ResponseBuilder(self._to_response(route.func(**args)), route)
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=APPLICATION_JSON,
body=json.dumps({"message": str(e)}),
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved
),
route,
)

@staticmethod
def _to_response(result: Union[Dict, Response]) -> Response:
Expand All @@ -493,6 +576,6 @@ def _to_response(result: Union[Dict, Response]) -> Response:
logger.debug("Simple response detected, serializing return before constructing final response")
return Response(
status_code=200,
content_type="application/json",
content_type=APPLICATION_JSON,
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
)
79 changes: 78 additions & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
from typing import Dict

from aws_lambda_powertools.event_handler.api_gateway import (
APPLICATION_JSON,
ApiGatewayResolver,
BadRequestError,
CORSConfig,
InternalServerError,
NotFoundError,
ProxyEventType,
Response,
ResponseBuilder,
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
Expand All @@ -24,7 +30,6 @@ def read_media(file_name: str) -> bytes:

LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
TEXT_HTML = "text/html"
APPLICATION_JSON = "application/json"


def test_alb_event():
Expand Down Expand Up @@ -429,6 +434,7 @@ def test_no_matches_with_cors():
# AND cors headers are returned
assert result["statusCode"] == 404
assert "Access-Control-Allow-Origin" in result["headers"]
assert "Not found" in result["body"]


def test_cors_preflight():
Expand Down Expand Up @@ -490,3 +496,74 @@ def custom_method():
assert headers["Content-Type"] == TEXT_HTML
assert "Access-Control-Allow-Origin" in result["headers"]
assert headers["Access-Control-Allow-Methods"] == "CUSTOM"


def test_service_error_response():
# GIVEN a service error response
app = ApiGatewayResolver(cors=CORSConfig())

@app.route(method="GET", rule="/bad-request-error", cors=False)
def bad_request_error():
raise BadRequestError("Missing required parameter")

@app.route(method="GET", rule="/unauthorized-error", cors=False)
def unauthorized_error():
raise UnauthorizedError("Unauthorized")

@app.route(method="GET", rule="/service-error", cors=True)
def service_error():
raise ServiceError(403, "Unauthorized")

@app.route(method="GET", rule="/not-found-error", cors=False)
def not_found_error():
raise NotFoundError

@app.route(method="GET", rule="/internal-server-error", cors=False)
def internal_server_error():
raise InternalServerError("Internal server error")

# WHEN calling the handler
# AND path is /bad-request-error
result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None)
# THEN return the bad request error response
# AND status code equals 400
assert result["statusCode"] == 400
assert result["body"] == json.dumps({"message": "Missing required parameter"})
assert result["headers"]["Content-Type"] == APPLICATION_JSON

# WHEN calling the handler
# AND path is /unauthorized-error
result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None)
# THEN return the unauthorized error response
# AND status code equals 401
assert result["statusCode"] == 401
assert result["body"] == json.dumps({"message": "Unauthorized"})
assert result["headers"]["Content-Type"] == APPLICATION_JSON

# WHEN calling the handler
# AND path is /service-error
result = app({"path": "/service-error", "httpMethod": "GET"}, None)
# THEN return the service error response
# AND status code equals 403
assert result["statusCode"] == 403
assert result["body"] == json.dumps({"message": "Unauthorized"})
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert "Access-Control-Allow-Origin" in result["headers"]

# WHEN calling the handler
# AND path is /not-found-error
result = app({"path": "/not-found-error", "httpMethod": "GET"}, None)
# THEN return the not found error response
# AND status code equals 404
assert result["statusCode"] == 404
assert result["body"] == json.dumps({"message": "Not found"})
assert result["headers"]["Content-Type"] == APPLICATION_JSON

# WHEN calling the handler
# AND path is /internal-server-error
result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None)
# THEN return the internal server error response
# AND status code equals 500
assert result["statusCode"] == 500
assert result["body"] == json.dumps({"message": "Internal server error"})
assert result["headers"]["Content-Type"] == APPLICATION_JSON