From ec468d3c1c613b74d39419948bdcf7349aae8dd7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 7 Nov 2018 17:57:12 +0000 Subject: [PATCH 1/4] Rejig ExceptionMiddleware and ServerErrorMiddleware --- starlette/applications.py | 15 +++++---- starlette/debug.py | 37 +++++++++++--------- starlette/exceptions.py | 71 ++++++++++++--------------------------- tests/test_debug.py | 8 ++--- tests/test_exceptions.py | 34 +++++++++---------- 5 files changed, 73 insertions(+), 92 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 464e9c572..373f14c99 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,6 +1,7 @@ import typing from starlette.datastructures import URL, URLPath +from starlette.debug import DebugMiddleware from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler from starlette.middleware.base import BaseHTTPMiddleware @@ -13,8 +14,8 @@ class Starlette: def __init__(self, debug: bool = False) -> None: self.router = Router() self.lifespan_handler = LifespanHandler() - self.app = self.router - self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.exception_middleware = ExceptionMiddleware(self.router) + self.debug_middleware = DebugMiddleware(self.exception_middleware, debug=debug) self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] @property @@ -23,11 +24,11 @@ def routes(self) -> typing.List[BaseRoute]: @property def debug(self) -> bool: - return self.exception_middleware.debug + return self.debug_middleware.debug @debug.setter def debug(self, value: bool) -> None: - self.exception_middleware.debug = value + self.debug_middleware.debug = value @property def schema(self) -> dict: @@ -41,7 +42,9 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: - self.exception_middleware.app = middleware_class(self.app, **kwargs) + self.debug_middleware.app = middleware_class( + self.debug_middleware.app, **kwargs + ) def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None: self.exception_middleware.add_exception_handler(exc_class, handler) @@ -107,4 +110,4 @@ def __call__(self, scope: Scope) -> ASGIInstance: scope["app"] = self if scope["type"] == "lifespan": return self.lifespan_handler(scope) - return self.exception_middleware(scope) + return self.debug_middleware(scope) diff --git a/starlette/debug.py b/starlette/debug.py index 32cac008d..5aa209cdc 100644 --- a/starlette/debug.py +++ b/starlette/debug.py @@ -63,30 +63,21 @@ def generate_plain_text(self) -> str: return "".join(traceback.format_tb(self.exc.__traceback__)) -def get_debug_response(request: Request, exc: Exception) -> Response: - accept = request.headers.get("accept", "") - debug_gen = DebugGenerator(exc) - - if "text/html" in accept: - content = debug_gen.generate_html() - return HTMLResponse(content, status_code=500) - content = debug_gen.generate_plain_text() - return PlainTextResponse(content, status_code=500) - - class DebugMiddleware: - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, debug: bool = False) -> None: self.app = app + self.debug = debug def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] != "http": return self.app(scope) - return _DebugResponder(self.app, scope) + return _DebugResponder(self.app, self.debug, scope) class _DebugResponder: - def __init__(self, app: ASGIApp, scope: Scope) -> None: + def __init__(self, app: ASGIApp, debug: bool, scope: Scope) -> None: self.app = app + self.debug = debug self.scope = scope self.response_started = False @@ -98,7 +89,10 @@ async def __call__(self, receive: Receive, send: Send) -> None: except Exception as exc: if not self.response_started: request = Request(self.scope) - response = get_debug_response(request, exc) + if self.debug: + response = self.debug_response(request, exc) + else: + response = self.error_response(request, exc) await response(receive, send) raise exc from None @@ -106,3 +100,16 @@ async def send(self, message: Message) -> None: if message["type"] == "http.response.start": self.response_started = True await self.raw_send(message) + + def debug_response(self, request: Request, exc: Exception) -> Response: + accept = request.headers.get("accept", "") + debug_gen = DebugGenerator(exc) + + if "text/html" in accept: + content = debug_gen.generate_html() + return HTMLResponse(content, status_code=500) + content = debug_gen.generate_plain_text() + return PlainTextResponse(content, status_code=500) + + def error_response(self, request: Request, exc: Exception) -> Response: + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index a8d7142b7..3fc810c67 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -2,7 +2,7 @@ import http import typing -from starlette.debug import get_debug_response +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send @@ -17,26 +17,22 @@ def __init__(self, status_code: int, detail: str = None) -> None: class ExceptionMiddleware: - def __init__(self, app: ASGIApp, debug: bool = False) -> None: + def __init__(self, app: ASGIApp) -> None: self.app = app - self.debug = debug self._exception_handlers = { - Exception: self.server_error, - HTTPException: self.http_exception, - } + HTTPException: self.http_exception + } # type: typing.Dict[typing.Type[Exception], typing.Callable] def add_exception_handler( self, exc_class: typing.Type[Exception], handler: typing.Callable ) -> None: - assert issubclass(exc_class, BaseException) + assert issubclass(exc_class, Exception) self._exception_handlers[exc_class] = handler def _lookup_exception_handler( self, exc: Exception ) -> typing.Optional[typing.Callable]: for cls in type(exc).__mro__: - if cls is Exception: - break if cls in self._exception_handlers: return self._exception_handlers[cls] return None @@ -56,48 +52,26 @@ async def sender(message: Message) -> None: await send(message) try: - try: - instance = self.app(scope) - await instance(receive, sender) - except Exception as exc: - # Exception handling is applied to any registed exception - # class or subclass that occurs within the application. - handler = self._lookup_exception_handler(exc) - - # Note that we always handle `Exception` in the outermost block. - if handler is None: - raise exc from None - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc - - request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) - else: - response = handler(request, exc) - await response(receive, sender) - + instance = self.app(scope) + await instance(receive, sender) except Exception as exc: - # The 'Exception' case always wraps everything else, and - # provides a last-ditch handler for dealing with server errors. + # Exception handling is applied to any registed exception + # class or subclass that occurs within the application. + handler = self._lookup_exception_handler(exc) + + if handler is None: + raise exc from None + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + request = Request(scope, receive=receive) - if self.debug: - handler = get_debug_response - else: - handler = self._exception_handlers[Exception] if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) # type: ignore + response = await handler(request, exc) else: - response = handler(request, exc) # type: ignore - if not response_started: - await response(receive, send) - - # We always raise the exception up to the server so that it - # is notified too. Typically this will mean that it'll log - # the exception. - raise + response = await run_in_threadpool(handler, request, exc) + await response(receive, sender) return app @@ -105,6 +79,3 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: if exc.status_code in {204, 304}: return Response(b"", status_code=exc.status_code) return PlainTextResponse(exc.detail, status_code=exc.status_code) - - def server_error(self, request: Request, exc: HTTPException) -> Response: - return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/tests/test_debug.py b/tests/test_debug.py index 48f71162e..c1d2763eb 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -12,7 +12,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = DebugMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 @@ -27,7 +27,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = DebugMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 @@ -44,7 +44,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = DebugMiddleware(app, debug=True) client = TestClient(app) with pytest.raises(RuntimeError): client.get("/") @@ -54,7 +54,7 @@ def test_debug_error_during_scope(): def app(scope): raise RuntimeError("Something went wrong") - app = DebugMiddleware(app) + app = DebugMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 51de7ff1b..a18ef3246 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -43,23 +43,23 @@ async def __call__(self, receive, send): client = TestClient(app) -def test_server_error(): - with pytest.raises(RuntimeError): - response = client.get("/runtime_error") - - allow_500_client = TestClient(app, raise_server_exceptions=False) - response = allow_500_client.get("/runtime_error") - assert response.status_code == 500 - assert response.text == "Internal Server Error" - - -def test_debug_enabled(): - app = ExceptionMiddleware(router) - app.debug = True - allow_500_client = TestClient(app, raise_server_exceptions=False) - response = allow_500_client.get("/runtime_error") - assert response.status_code == 500 - assert "RuntimeError" in response.text +# def test_server_error(): +# with pytest.raises(RuntimeError): +# response = client.get("/runtime_error") +# +# allow_500_client = TestClient(app, raise_server_exceptions=False) +# response = allow_500_client.get("/runtime_error") +# assert response.status_code == 500 +# assert response.text == "Internal Server Error" + + +# def test_debug_enabled(): +# app = ExceptionMiddleware(router) +# app.debug = True +# allow_500_client = TestClient(app, raise_server_exceptions=False) +# response = allow_500_client.get("/runtime_error") +# assert response.status_code == 500 +# assert "RuntimeError" in response.text def test_not_acceptable(): From 4e2b60bc6b717dbb3d600df84b3af99f678b69e4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 8 Nov 2018 09:53:51 +0000 Subject: [PATCH 2/4] Tweak DebugMiddleware implementation --- starlette/debug.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/starlette/debug.py b/starlette/debug.py index 5aa209cdc..faad9835d 100644 --- a/starlette/debug.py +++ b/starlette/debug.py @@ -1,3 +1,4 @@ +import functools import traceback from starlette.requests import Request @@ -71,24 +72,24 @@ def __init__(self, app: ASGIApp, debug: bool = False) -> None: def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] != "http": return self.app(scope) - return _DebugResponder(self.app, self.debug, scope) + return functools.partial(self.asgi, scope=scope) + async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None: + response_started = False -class _DebugResponder: - def __init__(self, app: ASGIApp, debug: bool, scope: Scope) -> None: - self.app = app - self.debug = debug - self.scope = scope - self.response_started = False + async def _send(message: Message) -> None: + nonlocal response_started, send + + if message["type"] == "http.response.start": + response_started = True + await send(message) - async def __call__(self, receive: Receive, send: Send) -> None: - self.raw_send = send try: - asgi = self.app(self.scope) - await asgi(receive, self.send) + asgi = self.app(scope) + await asgi(receive, _send) except Exception as exc: - if not self.response_started: - request = Request(self.scope) + if not response_started: + request = Request(scope) if self.debug: response = self.debug_response(request, exc) else: @@ -96,11 +97,6 @@ async def __call__(self, receive: Receive, send: Send) -> None: await response(receive, send) raise exc from None - async def send(self, message: Message) -> None: - if message["type"] == "http.response.start": - self.response_started = True - await self.raw_send(message) - def debug_response(self, request: Request, exc: Exception) -> Response: accept = request.headers.get("accept", "") debug_gen = DebugGenerator(exc) From d964339591431ef1cc52bbeb27f4b5e84e4700b3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 8 Nov 2018 11:13:53 +0000 Subject: [PATCH 3/4] Support custom 500 handlers --- starlette/applications.py | 38 +++++++++++++------ starlette/exceptions.py | 25 ++++++++---- starlette/{debug.py => middleware/errors.py} | 36 ++++++++++++++++-- .../test_errors.py} | 31 +++++++++++---- tests/test_applications.py | 13 +++++-- 5 files changed, 110 insertions(+), 33 deletions(-) rename starlette/{debug.py => middleware/errors.py} (74%) rename tests/{test_debug.py => middleware/test_errors.py} (69%) diff --git a/starlette/applications.py b/starlette/applications.py index 373f14c99..9eef74472 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,10 +1,10 @@ import typing from starlette.datastructures import URL, URLPath -from starlette.debug import DebugMiddleware from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import BaseRoute, Router from starlette.schemas import BaseSchemaGenerator from starlette.types import ASGIApp, ASGIInstance, Scope @@ -12,10 +12,13 @@ class Starlette: def __init__(self, debug: bool = False) -> None: + self._debug = debug self.router = Router() self.lifespan_handler = LifespanHandler() - self.exception_middleware = ExceptionMiddleware(self.router) - self.debug_middleware = DebugMiddleware(self.exception_middleware, debug=debug) + self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.error_middleware = ServerErrorMiddleware( + self.exception_middleware, debug=debug + ) self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] @property @@ -24,11 +27,13 @@ def routes(self) -> typing.List[BaseRoute]: @property def debug(self) -> bool: - return self.debug_middleware.debug + return self._debug @debug.setter def debug(self, value: bool) -> None: - self.debug_middleware.debug = value + self._debug = value + self.exception_middleware.debug = value + self.error_middleware.debug = value @property def schema(self) -> dict: @@ -42,12 +47,19 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: - self.debug_middleware.app = middleware_class( - self.debug_middleware.app, **kwargs + self.error_middleware.app = middleware_class( + self.error_middleware.app, **kwargs ) - def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None: - self.exception_middleware.add_exception_handler(exc_class, handler) + def add_exception_handler( + self, + category: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable, + ) -> None: + if category in (500, Exception): + self.error_middleware.handler = handler + else: + self.exception_middleware.add_exception_handler(category, handler) def add_event_handler(self, event_type: str, func: typing.Callable) -> None: self.lifespan_handler.add_event_handler(event_type, func) @@ -64,9 +76,11 @@ def add_route( def add_websocket_route(self, path: str, route: typing.Callable) -> None: self.router.add_websocket_route(path, route) - def exception_handler(self, exc_class: type) -> typing.Callable: + def exception_handler( + self, category: typing.Union[int, typing.Type[Exception]] + ) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: - self.add_exception_handler(exc_class, func) + self.add_exception_handler(category, func) return func return decorator @@ -110,4 +124,4 @@ def __call__(self, scope: Scope) -> ASGIInstance: scope["app"] = self if scope["type"] == "lifespan": return self.lifespan_handler(scope) - return self.debug_middleware(scope) + return self.error_middleware(scope) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 3fc810c67..e51628f3d 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -17,17 +17,24 @@ def __init__(self, status_code: int, detail: str = None) -> None: class ExceptionMiddleware: - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, debug: bool = False) -> None: self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers = {} # type: typing.Dict[int, typing.Callable] self._exception_handlers = { HTTPException: self.http_exception } # type: typing.Dict[typing.Type[Exception], typing.Callable] def add_exception_handler( - self, exc_class: typing.Type[Exception], handler: typing.Callable + self, + category: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable, ) -> None: - assert issubclass(exc_class, Exception) - self._exception_handlers[exc_class] = handler + if isinstance(category, int): + self._status_handlers[category] = handler + else: + assert issubclass(category, Exception) + self._exception_handlers[category] = handler def _lookup_exception_handler( self, exc: Exception @@ -55,9 +62,13 @@ async def sender(message: Message) -> None: instance = self.app(scope) await instance(receive, sender) except Exception as exc: - # Exception handling is applied to any registed exception - # class or subclass that occurs within the application. - handler = self._lookup_exception_handler(exc) + handler = None + + if isinstance(exc, HTTPException): + handler = self._status_handlers.get(exc.status_code) + + if handler is None: + handler = self._lookup_exception_handler(exc) if handler is None: raise exc from None diff --git a/starlette/debug.py b/starlette/middleware/errors.py similarity index 74% rename from starlette/debug.py rename to starlette/middleware/errors.py index faad9835d..021b8064d 100644 --- a/starlette/debug.py +++ b/starlette/middleware/errors.py @@ -1,6 +1,9 @@ +import asyncio import functools import traceback +import typing +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send @@ -64,9 +67,23 @@ def generate_plain_text(self) -> str: return "".join(traceback.format_tb(self.exc.__traceback__)) -class DebugMiddleware: - def __init__(self, app: ASGIApp, debug: bool = False) -> None: +class ServerErrorMiddleware: + """ + Handles returning 500 responses when a server error occurs. + + If 'debug' is set, then traceback responses will be returned, + otherwise the designated 'handler' will be called. + + This middleware class should generally be used to wrap everything + else up, so that unhandled exceptions anywhere in the stack + always result in an appropriate 500 response. + """ + + def __init__( + self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False + ) -> None: self.app = app + self.handler = handler self.debug = debug def __call__(self, scope: Scope) -> ASGIInstance: @@ -91,10 +108,23 @@ async def _send(message: Message) -> None: if not response_started: request = Request(scope) if self.debug: + # In debug mode, return traceback responses. response = self.debug_response(request, exc) - else: + elif self.handler is None: + # Use our default 500 error handler. response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if asyncio.iscoroutinefunction(self.handler): + response = await self.handler(request, exc) + else: + response = await run_in_threadpool(self.handler, request, exc) + await response(receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or allows test clients + # to optionally raise the error within the test case. raise exc from None def debug_response(self, request: Request, exc: Exception) -> Response: diff --git a/tests/test_debug.py b/tests/middleware/test_errors.py similarity index 69% rename from tests/test_debug.py rename to tests/middleware/test_errors.py index c1d2763eb..ccaaafb67 100644 --- a/tests/test_debug.py +++ b/tests/middleware/test_errors.py @@ -1,10 +1,27 @@ import pytest -from starlette.debug import DebugMiddleware -from starlette.responses import Response +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.responses import JSONResponse, Response from starlette.testclient import TestClient +def test_handler(): + def app(scope): + async def asgi(receive, send): + raise RuntimeError("Something went wrong") + + return asgi + + def error_500(request, exc): + return JSONResponse({"detail": "Server Error"}, status_code=500) + + app = ServerErrorMiddleware(app, handler=error_500) + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") + assert response.status_code == 500 + assert response.json() == {"detail": "Server Error"} + + def test_debug_text(): def app(scope): async def asgi(receive, send): @@ -12,7 +29,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app, debug=True) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 @@ -27,7 +44,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app, debug=True) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 @@ -44,7 +61,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app, debug=True) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app) with pytest.raises(RuntimeError): client.get("/") @@ -54,7 +71,7 @@ def test_debug_error_during_scope(): def app(scope): raise RuntimeError("Something went wrong") - app = DebugMiddleware(app, debug=True) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 @@ -70,7 +87,7 @@ def test_debug_not_http(): def app(scope): raise RuntimeError("Something went wrong") - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): app({"type": "websocket"}) diff --git a/tests/test_applications.py b/tests/test_applications.py index 323914a07..ef8496c8b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -29,13 +29,18 @@ def __call__(self, scope): app.add_middleware(TrustedHostMiddleware, hostname="testserver") -@app.exception_handler(Exception) +@app.exception_handler(500) async def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) +@app.exception_handler(405) +async def method_not_allowed(request, exc): + return JSONResponse({"detail": "Custom message"}, status_code=405) + + @app.exception_handler(HTTPException) -async def handler(request, exc): +async def http_exception(request, exc): return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) @@ -136,11 +141,11 @@ def test_400(): def test_405(): response = client.post("/func") assert response.status_code == 405 - assert response.json() == {"detail": "Method Not Allowed"} + assert response.json() == {"detail": "Custom message"} response = client.post("/class") assert response.status_code == 405 - assert response.json() == {"detail": "Method Not Allowed"} + assert response.json() == {"detail": "Custom message"} def test_500(): From 5bac15e01c8dc242eb8a09c5345e38a7eb7b7c81 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 8 Nov 2018 11:56:39 +0000 Subject: [PATCH 4/4] Exception handling updates --- docs/applications.md | 4 +- docs/debug.md | 21 ------- docs/exceptions.md | 111 ++++++++++++++------------------- mkdocs.yml | 1 - starlette/applications.py | 12 ++-- starlette/exceptions.py | 10 +-- starlette/middleware/errors.py | 2 +- 7 files changed, 62 insertions(+), 99 deletions(-) delete mode 100644 docs/debug.md diff --git a/docs/applications.md b/docs/applications.md index a9f3aed58..aac8592e5 100644 --- a/docs/applications.md +++ b/docs/applications.md @@ -69,6 +69,6 @@ Submounting applications is a powerful way to include reusable ASGI applications You can use either of the following to catch and handle particular types of exceptions that occur within the application: -* `app.add_exception_handler(exc_class, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`. -* `@app.exception_handler(exc_class)` - Add an error handler, decorator style. +* `app.add_exception_handler(exc_class_or_status_code, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`. +* `@app.exception_handler(exc_class_or_status_code)` - Add an error handler, decorator style. * `app.debug` - Enable or disable error tracebacks in the browser. diff --git a/docs/debug.md b/docs/debug.md deleted file mode 100644 index 52e194cf9..000000000 --- a/docs/debug.md +++ /dev/null @@ -1,21 +0,0 @@ - -You can use Starlette's `DebugMiddleware` to display simple error tracebacks in the browser. - -```python -from starlette.debug import DebugMiddleware - - -class App: - def __init__(self, scope): - self.scope = scope - - async def __call__(self, receive, send): - raise RuntimeError('Something went wrong') - - -app = DebugMiddleware(App) -``` - -For a more complete handling of exception cases you may wish to use Starlette's -[`ExceptionMiddleware`](../exceptions/) class instead, which also includes -optional debug handling. diff --git a/docs/exceptions.md b/docs/exceptions.md index 09de6e833..03908f86b 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -1,93 +1,76 @@ -Starlette includes an exception handling middleware that you can use in order -to dispatch different classes of exceptions to different handlers. - -To see how this works, we'll start by with this small ASGI application: +Starlette allows you to install custom exception handlers to deal with +how you return responses when errors or handled exceptions occur. ```python -from starlette.exceptions import ExceptionMiddleware, HTTPException - - -class App: - def __init__(self, scope): - raise HTTPException(status_code=403) - - -app = ExceptionMiddleware(App) -``` - -If you run the app and make an HTTP request to it, you'll get a plain text -response with a "403 Permission Denied" response. This is the behaviour that the -default handler responds with when an `HTTPException` class or subclass is raised. - -Let's change the exception handling, so that we get JSON error responses -instead: +from starlette.applications import Starlette +from starlette.responses import HTMLResponse -```python -from starlette.exceptions import ExceptionMiddleware, HTTPException -from starlette.responses import JSONResponse - +HTML_404_PAGE = ... +HTML_500_PAGE = ... -class App: - def __init__(self, scope): - raise HTTPException(status_code=403) +app = Starlette() -def handler(request, exc): - return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) +@app.exception_handler(404) +async def not_found(request, exc): + return HTMLResponse(content=HTML_404_PAGE) -app = ExceptionMiddleware(App) -app.add_exception_handler(HTTPException, handler) +@app.exception_handler(500) +async def server_error(request, exc): + return HTMLResponse(content=HTML_500_PAGE) ``` -Now if we make a request to the application, we'll get back a JSON encoded -HTTP response. - -By default two types of exceptions are caught and dealt with: - -* `HTTPException` - Used to raise standard HTTP error codes. -* `Exception` - Used as a catch-all handler to deal with any `500 Internal -Server Error` responses. The `Exception` case also wraps any other exception -handling. +If `debug` is enabled and an error occurs, then instead of using the installed +500 handler, Starlette will respond with a traceback response. -The catch-all `Exception` case is used to return simple `500 Internal Server Error` -responses. During development you might want to switch the behaviour so that -it displays an error traceback in the browser: - -``` -app = ExceptionMiddleware(App, debug=True) +```python +app = Starlette(debug=True) ``` -This uses the same error tracebacks as the more minimal [`DebugMiddleware`](../debugging). +As well as registering handlers for specific status codes, you can also +register handlers for classes of exceptions. -The exception handler currently only catches and deals with exceptions within -HTTP requests. Any websocket exceptions will simply be raised to the server -and result in an error log. +In particular you might want to override how the built-in `HTTPException` class +is handled. For example, to use JSON style responses: -## ExceptionMiddleware +```python +@app.exception_handler(HTTPException) +async def http_exception(request, exc): + return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) +``` -The exception middleware catches and handles the exceptions, returning -appropriate HTTP responses. +## Errors and handled exceptions -* `ExceptionMiddleware(app, debug=False)` - Instantiate the exception handler, -wrapping up it around an inner ASGI application. +It is important to differentiate between handled exceptions and errors. -Adding handlers: +Handled exceptions do not represent error cases. They are coerced into appropriate +HTTP responses, which are then sent through the standard middleware stack. By default +the `HTTPException` class is used to manage any handled exceptions. -* `.add_exception_handler(exc_class, handler)` - Set a handler function to run -for the given exception class. +Errors are any other exception that occurs within the application. These cases +should bubble through the entire middleware stack as exceptions. Any error +logging middleware should ensure that it re-raises the exception all the +way up to the server. -Enabling debug mode: +In order to deal with this behaviour correctly, the middleware stack of a +`Starlette` application is configured like this: -* `.debug` - If set to `True`, then the catch-all handler for `Exception` will -not be used, and error tracebacks will be sent as responses instead. +* `ServerErrorMiddleware` - Returns 500 responses when server errors occur. +* Installed middleware +* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses. +* Router +* Endpoints ## HTTPException The `HTTPException` class provides a base class that you can use for any -standard HTTP error conditions. The `ExceptionMiddleware` implementation -defaults to returning plain-text HTTP responses for any `HTTPException`. +handled exceptions. The `ExceptionMiddleware` implementation defaults to +returning plain-text HTTP responses for any `HTTPException`. * `HTTPException(status_code, detail=None)` + +You should only raise `HTTPException` inside routing or endpoints. Middleware +classes should instead just return appropriate responses directly. diff --git a/mkdocs.yml b/mkdocs.yml index 91bd5a084..c7cb8fe3e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,7 +24,6 @@ nav: - Events: 'events.md' - Background Tasks: 'background.md' - Exceptions: 'exceptions.md' - - Debug: 'debug.md' - Test Client: 'testclient.md' - Release Notes: 'release-notes.md' diff --git a/starlette/applications.py b/starlette/applications.py index 9eef74472..db47ef62f 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -53,13 +53,15 @@ def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: def add_exception_handler( self, - category: typing.Union[int, typing.Type[Exception]], + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable, ) -> None: - if category in (500, Exception): + if exc_class_or_status_code in (500, Exception): self.error_middleware.handler = handler else: - self.exception_middleware.add_exception_handler(category, handler) + self.exception_middleware.add_exception_handler( + exc_class_or_status_code, handler + ) def add_event_handler(self, event_type: str, func: typing.Callable) -> None: self.lifespan_handler.add_event_handler(event_type, func) @@ -77,10 +79,10 @@ def add_websocket_route(self, path: str, route: typing.Callable) -> None: self.router.add_websocket_route(path, route) def exception_handler( - self, category: typing.Union[int, typing.Type[Exception]] + self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]] ) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: - self.add_exception_handler(category, func) + self.add_exception_handler(exc_class_or_status_code, func) return func return decorator diff --git a/starlette/exceptions.py b/starlette/exceptions.py index e51628f3d..64691e8cf 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -27,14 +27,14 @@ def __init__(self, app: ASGIApp, debug: bool = False) -> None: def add_exception_handler( self, - category: typing.Union[int, typing.Type[Exception]], + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable, ) -> None: - if isinstance(category, int): - self._status_handlers[category] = handler + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler else: - assert issubclass(category, Exception) - self._exception_handlers[category] = handler + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler def _lookup_exception_handler( self, exc: Exception diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 021b8064d..56cb20f47 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -74,7 +74,7 @@ class ServerErrorMiddleware: If 'debug' is set, then traceback responses will be returned, otherwise the designated 'handler' will be called. - This middleware class should generally be used to wrap everything + This middleware class should generally be used to wrap *everything* else up, so that unhandled exceptions anywhere in the stack always result in an appropriate 500 response. """