diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py new file mode 100644 index 0000000000..8a9beb3b29 --- /dev/null +++ b/starlette/_exception_handler.py @@ -0,0 +1,76 @@ +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket + +Handler = typing.Callable[..., typing.Any] +ExceptionHandlers = typing.Dict[typing.Any, Handler] +StatusHandlers = typing.Dict[int, Handler] + + +def _lookup_exception_handler( + exc_handlers: ExceptionHandlers, exc: Exception +) -> typing.Optional[Handler]: + for cls in type(exc).__mro__: + if cls in exc_handlers: + return exc_handlers[cls] + return None + + +def wrap_app_handling_exceptions( + app: ASGIApp, conn: typing.Union[Request, WebSocket] +) -> ASGIApp: + exception_handlers: ExceptionHandlers + status_handlers: StatusHandlers + try: + exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exception_handlers, exc) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if scope["type"] == "http": + response: Response + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) + await response(scope, receive, sender) + elif scope["type"] == "websocket": + if is_async_callable(handler): + await handler(conn, exc) + else: + await run_in_threadpool(handler, conn, exc) + + return wrapped_app diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index cd72941704..59010c7e68 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -1,11 +1,14 @@ import typing -from starlette._utils import is_async_callable -from starlette.concurrency import run_in_threadpool +from starlette._exception_handler import ( + ExceptionHandlers, + StatusHandlers, + wrap_app_handling_exceptions, +) from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket @@ -20,12 +23,10 @@ def __init__( ) -> None: self.app = app self.debug = debug # TODO: We ought to handle 404 cases if debug is set. - self._status_handlers: typing.Dict[int, typing.Callable] = {} - self._exception_handlers: typing.Dict[ - typing.Type[Exception], typing.Callable - ] = { + self._status_handlers: StatusHandlers = {} + self._exception_handlers: ExceptionHandlers = { HTTPException: self.http_exception, - WebSocketException: self.websocket_exception, + WebSocketException: self.websocket_exception, # type: ignore[dict-item] } if handlers is not None: for key, value in handlers.items(): @@ -42,68 +43,32 @@ def add_exception_handler( assert issubclass(exc_class_or_status_code, Exception) self._exception_handlers[exc_class_or_status_code] = handler - def _lookup_exception_handler( - self, exc: Exception - ) -> typing.Optional[typing.Callable]: - for cls in type(exc).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): await self.app(scope, receive, send) return - response_started = False - - async def sender(message: Message) -> None: - nonlocal response_started - - if message["type"] == "http.response.start": - response_started = True - await send(message) - - try: - await self.app(scope, receive, sender) - except Exception as exc: - handler = None - - if isinstance(exc, HTTPException): - handler = self._status_handlers.get(exc.status_code) - - if handler is None: - handler = self._lookup_exception_handler(exc) - - if handler is None: - raise exc + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + conn: typing.Union[Request, WebSocket] + if scope["type"] == "http": + conn = Request(scope, receive, send) + else: + conn = WebSocket(scope, receive, send) - if scope["type"] == "http": - request = Request(scope, receive=receive) - if is_async_callable(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) - elif scope["type"] == "websocket": - websocket = WebSocket(scope, receive=receive, send=send) - if is_async_callable(handler): - await handler(websocket, exc) - else: - await run_in_threadpool(handler, websocket, exc) + await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) - def http_exception(self, request: Request, exc: HTTPException) -> Response: + def http_exception(self, request: Request, exc: Exception) -> Response: + assert isinstance(exc, HTTPException) if exc.status_code in {204, 304}: return Response(status_code=exc.status_code, headers=exc.headers) return PlainTextResponse( exc.detail, status_code=exc.status_code, headers=exc.headers ) - async def websocket_exception( - self, websocket: WebSocket, exc: WebSocketException - ) -> None: - await websocket.close(code=exc.code, reason=exc.reason) + async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: + assert isinstance(exc, WebSocketException) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/starlette/routing.py b/starlette/routing.py index 52cf174e1f..8e01c8562c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -9,6 +9,7 @@ from contextlib import asynccontextmanager from enum import Enum +from starlette._exception_handler import wrap_app_handling_exceptions from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor @@ -61,12 +62,16 @@ def request_response(func: typing.Callable) -> ASGIApp: is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive=receive, send=send) - if is_coroutine: - response = await func(request) - else: - response = await run_in_threadpool(func, request) - await response(scope, receive, send) + request = Request(scope, receive, send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + if is_coroutine: + response = await func(request) + else: + response = await run_in_threadpool(func, request) + await response(scope, receive, send) + + await wrap_app_handling_exceptions(app, request)(scope, receive, send) return app @@ -79,7 +84,11 @@ def websocket_session(func: typing.Callable) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: session = WebSocket(scope, receive=receive, send=send) - await func(session) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await func(session) + + await wrap_app_handling_exceptions(app, session)(scope, receive, send) return app diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 05583a430b..2f2b891673 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,7 +4,8 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware -from starlette.responses import PlainTextResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -28,6 +29,22 @@ def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) +class BadBodyException(HTTPException): + pass + + +async def read_body_and_raise_exc(request: Request): + await request.body() + raise BadBodyException(422) + + +async def handler_that_reads_body( + request: Request, exc: BadBodyException +) -> JSONResponse: + body = await request.body() + return JSONResponse(status_code=422, content={"body": body.decode()}) + + class HandledExcAfterResponse: async def __call__(self, scope, receive, send): response = PlainTextResponse("OK", status_code=200) @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), + Route( + "/consume_body_in_endpoint_and_handler", + endpoint=read_body_and_raise_exc, + methods=["POST"], + ), ] ) -app = ExceptionMiddleware(router) +app = ExceptionMiddleware( + router, + handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item] +) @pytest.fixture @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None: with pytest.warns(DeprecationWarning): starlette.exceptions.ExceptionMiddleware + + +def test_request_in_app_and_handler_is_the_same_object(client) -> None: + response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!") + assert response.status_code == 422 + assert response.json() == {"body": "Hello!"} diff --git a/tests/test_routing.py b/tests/test_routing.py index 298745d407..1292932243 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1033,13 +1033,9 @@ async def modified_send(msg: Message) -> None: assert resp.status_code == 200, resp.content assert "X-Mounted" in resp.headers - # this is the "surprising" behavior bit - # the middleware on the mount never runs because there - # is nothing to catch the HTTPException - # since Mount middlweare is not wrapped by ExceptionMiddleware resp = client.get("/mount/err") assert resp.status_code == 403, resp.content - assert "X-Mounted" not in resp.headers + assert "X-Mounted" in resp.headers def test_route_repr() -> None: