Skip to content

Commit

Permalink
Move exception handling logic to Route (#2026)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
adriangb and Kludex authored Jun 7, 2023
1 parent e981768 commit e99738b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 73 deletions.
76 changes: 76 additions & 0 deletions starlette/_exception_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

Handler = typing.Callable[..., typing.Any]
ExceptionHandlers = typing.Dict[typing.Any, Handler]
StatusHandlers = typing.Dict[int, Handler]


def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> typing.Optional[Handler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None


def wrap_app_handling_exceptions(
app: ASGIApp, conn: typing.Union[Request, WebSocket]
) -> ASGIApp:
exception_handlers: ExceptionHandlers
status_handlers: StatusHandlers
try:
exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exception_handlers, exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if scope["type"] == "http":
response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)

return wrapped_app
83 changes: 24 additions & 59 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand All @@ -20,12 +23,10 @@ def __init__(
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
WebSocketException: self.websocket_exception, # type: ignore[dict-item]
}
if handlers is not None:
for key, value in handlers.items():
Expand All @@ -42,68 +43,32 @@ def add_exception_handler(
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)

if handler is None:
handler = self._lookup_exception_handler(exc)

if handler is None:
raise exc
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
conn: typing.Union[Request, WebSocket]
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)

if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
23 changes: 16 additions & 7 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import asynccontextmanager
from enum import Enum

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
Expand Down Expand Up @@ -61,12 +62,16 @@ def request_response(func: typing.Callable) -> ASGIApp:
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)
request = Request(scope, receive, send)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
await response(scope, receive, send)

await wrap_app_handling_exceptions(app, request)(scope, receive, send)

return app

Expand All @@ -79,7 +84,11 @@ def websocket_session(func: typing.Callable) -> ASGIApp:

async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)

await wrap_app_handling_exceptions(app, session)(scope, receive, send)

return app

Expand Down
35 changes: 33 additions & 2 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute


Expand All @@ -28,6 +29,22 @@ def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
await request.body()
raise BadBodyException(422)


async def handler_that_reads_body(
request: Request, exc: BadBodyException
) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send):
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
]
)


app = ExceptionMiddleware(router)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)


@pytest.fixture
Expand Down Expand Up @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None:

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
6 changes: 1 addition & 5 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,13 +1033,9 @@ async def modified_send(msg: Message) -> None:
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers
assert "X-Mounted" in resp.headers


def test_route_repr() -> None:
Expand Down

0 comments on commit e99738b

Please sign in to comment.