Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api-gateway): add common HTTP service errors #506

Merged
merged 8 commits into from
Jul 6, 2021
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Event handler decorators for common Lambda events
"""

from .api_gateway import ApiGatewayResolver
from .appsync import AppSyncResolver

__all__ = ["AppSyncResolver"]
__all__ = ["AppSyncResolver", "ApiGatewayResolver"]
33 changes: 25 additions & 8 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import re
import zlib
from enum import Enum
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
Expand Down Expand Up @@ -466,19 +469,28 @@ def _not_found(self, method: str) -> ResponseBuilder:

return ResponseBuilder(
Response(
status_code=404,
content_type="application/json",
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=json.dumps({"message": "Not found"}),
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
)
)

def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
return ResponseBuilder(self._to_response(route.func(**args)), route)
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
),
route,
)

@staticmethod
def _to_response(result: Union[Dict, Response]) -> Response:
def _to_response(self, result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response

2 main result types are supported:
Expand All @@ -493,6 +505,11 @@ def _to_response(result: Union[Dict, Response]) -> Response:
logger.debug("Simple response detected, serializing return before constructing final response")
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
content_type=content_types.APPLICATION_JSON,
body=self._json_dump(result),
)

@staticmethod
def _json_dump(obj: Any) -> str:
"""Does a concise json serialization"""
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/content_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
APPLICATION_JSON = "application/json"
PLAIN_TEXT = "plain/text"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be “text/plain”

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching it, fixed now in develop w/ a note on mimetypes lib.

45 changes: 45 additions & 0 deletions aws_lambda_powertools/event_handler/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from http import HTTPStatus


class ServiceError(Exception):
"""Service Error"""

def __init__(self, status_code: int, msg: str):
"""
Parameters
----------
status_code: int
Http status code
msg: str
Error message
"""
self.status_code = status_code
self.msg = msg


class BadRequestError(ServiceError):
"""Bad Request Error"""

def __init__(self, msg: str):
super().__init__(HTTPStatus.BAD_REQUEST, msg)


class UnauthorizedError(ServiceError):
"""Unauthorized Error"""

def __init__(self, msg: str):
super().__init__(HTTPStatus.UNAUTHORIZED, msg)


class NotFoundError(ServiceError):
"""Not Found Error"""

def __init__(self, msg: str = "Not found"):
super().__init__(HTTPStatus.NOT_FOUND, msg)


class InternalServerError(ServiceError):
"""Internal Server Error"""

def __init__(self, message: str):
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message)
110 changes: 101 additions & 9 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
from pathlib import Path
from typing import Dict

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
CORSConfig,
ProxyEventType,
Response,
ResponseBuilder,
)
from aws_lambda_powertools.event_handler.exceptions import (
BadRequestError,
InternalServerError,
NotFoundError,
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
from tests.functional.utils import load_event
Expand All @@ -24,7 +32,6 @@ def read_media(file_name: str) -> bytes:

LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
TEXT_HTML = "text/html"
APPLICATION_JSON = "application/json"


def test_alb_event():
Expand Down Expand Up @@ -55,15 +62,15 @@ def test_api_gateway_v1():
def get_lambda() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEvent)
assert app.lambda_context == {}
return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"}))
return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"}))

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})

# THEN process event correctly
# AND set the current_event type as APIGatewayProxyEvent
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON


def test_api_gateway():
Expand Down Expand Up @@ -93,15 +100,15 @@ def test_api_gateway_v2():
def my_path() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEventV2)
post_data = app.current_event.json_body
return Response(200, "plain/text", post_data["username"])
return Response(200, content_types.PLAIN_TEXT, post_data["username"])

# WHEN calling the event handler
result = app(load_event("apiGatewayProxyV2Event.json"), {})

# THEN process event correctly
# AND set the current_event type as APIGatewayProxyEventV2
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == "plain/text"
assert result["headers"]["Content-Type"] == content_types.PLAIN_TEXT
assert result["body"] == "tom"


Expand Down Expand Up @@ -215,7 +222,7 @@ def test_compress():

@app.get("/my/request", compress=True)
def with_compression() -> Response:
return Response(200, APPLICATION_JSON, expected_value)
return Response(200, content_types.APPLICATION_JSON, expected_value)

def handler(event, context):
return app.resolve(event, context)
Expand Down Expand Up @@ -261,7 +268,7 @@ def test_compress_no_accept_encoding():

@app.get("/my/path", compress=True)
def return_text() -> Response:
return Response(200, "text/plain", expected_value)
return Response(200, content_types.PLAIN_TEXT, expected_value)

# WHEN calling the event handler
result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)
Expand Down Expand Up @@ -327,7 +334,7 @@ def rest_func() -> Dict:

# THEN automatically process this as a json rest api response
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder)
assert result["body"] == expected_str

Expand Down Expand Up @@ -382,7 +389,7 @@ def another_one():
# THEN routes by default return the custom cors headers
assert "headers" in result
headers = result["headers"]
assert headers["Content-Type"] == APPLICATION_JSON
assert headers["Content-Type"] == content_types.APPLICATION_JSON
assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin
expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))
assert headers["Access-Control-Allow-Headers"] == expected_allows_headers
Expand Down Expand Up @@ -429,6 +436,7 @@ def test_no_matches_with_cors():
# AND cors headers are returned
assert result["statusCode"] == 404
assert "Access-Control-Allow-Origin" in result["headers"]
assert "Not found" in result["body"]


def test_cors_preflight():
Expand Down Expand Up @@ -490,3 +498,87 @@ def custom_method():
assert headers["Content-Type"] == TEXT_HTML
assert "Access-Control-Allow-Origin" in result["headers"]
assert headers["Access-Control-Allow-Methods"] == "CUSTOM"


def test_service_error_responses():
# SCENARIO handling different kind of service errors being raised
app = ApiGatewayResolver(cors=CORSConfig())

def json_dump(obj):
return json.dumps(obj, separators=(",", ":"))

# GIVEN an BadRequestError
@app.get(rule="/bad-request-error", cors=False)
def bad_request_error():
raise BadRequestError("Missing required parameter")

# WHEN calling the handler
# AND path is /bad-request-error
result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None)
# THEN return the bad request error response
# AND status code equals 400
assert result["statusCode"] == 400
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 400, "message": "Missing required parameter"}
assert result["body"] == json_dump(expected)

# GIVEN an UnauthorizedError
@app.get(rule="/unauthorized-error", cors=False)
def unauthorized_error():
raise UnauthorizedError("Unauthorized")

# WHEN calling the handler
# AND path is /unauthorized-error
result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None)
# THEN return the unauthorized error response
# AND status code equals 401
assert result["statusCode"] == 401
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 401, "message": "Unauthorized"}
assert result["body"] == json_dump(expected)

# GIVEN an NotFoundError
@app.get(rule="/not-found-error", cors=False)
def not_found_error():
raise NotFoundError

# WHEN calling the handler
# AND path is /not-found-error
result = app({"path": "/not-found-error", "httpMethod": "GET"}, None)
# THEN return the not found error response
# AND status code equals 404
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 404, "message": "Not found"}
assert result["body"] == json_dump(expected)

# GIVEN an InternalServerError
@app.get(rule="/internal-server-error", cors=False)
def internal_server_error():
raise InternalServerError("Internal server error")

# WHEN calling the handler
# AND path is /internal-server-error
result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None)
# THEN return the internal server error response
# AND status code equals 500
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 500, "message": "Internal server error"}
assert result["body"] == json_dump(expected)

# GIVEN an ServiceError with a custom status code
@app.get(rule="/service-error", cors=True)
def service_error():
raise ServiceError(502, "Something went wrong!")

# WHEN calling the handler
# AND path is /service-error
result = app({"path": "/service-error", "httpMethod": "GET"}, None)
# THEN return the service error response
# AND status code equals 502
assert result["statusCode"] == 502
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert "Access-Control-Allow-Origin" in result["headers"]
expected = {"statusCode": 502, "message": "Something went wrong!"}
assert result["body"] == json_dump(expected)