From 9f3dcb7d2b8556773e0d0bf7fd61311dded2ea35 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 8 Nov 2018 11:59:15 +0000 Subject: [PATCH] Rejig ExceptionMiddleware and ServerErrorMiddleware (#193) * Rejig ExceptionMiddleware and ServerErrorMiddleware * Tweak DebugMiddleware implementation * Support custom 500 handlers * Exception handling updates --- docs/applications.md | 4 +- docs/debug.md | 21 --- docs/exceptions.md | 111 ++++++-------- mkdocs.yml | 1 - starlette/applications.py | 35 ++++- starlette/debug.py | 108 -------------- starlette/exceptions.py | 78 ++++------ starlette/middleware/errors.py | 141 ++++++++++++++++++ .../test_errors.py} | 31 +++- tests/test_applications.py | 13 +- tests/test_exceptions.py | 34 ++--- 11 files changed, 297 insertions(+), 280 deletions(-) delete mode 100644 docs/debug.md delete mode 100644 starlette/debug.py create mode 100644 starlette/middleware/errors.py rename tests/{test_debug.py => middleware/test_errors.py} (69%) diff --git a/docs/applications.md b/docs/applications.md index a9f3aed58..aac8592e5 100644 --- a/docs/applications.md +++ b/docs/applications.md @@ -69,6 +69,6 @@ Submounting applications is a powerful way to include reusable ASGI applications You can use either of the following to catch and handle particular types of exceptions that occur within the application: -* `app.add_exception_handler(exc_class, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`. -* `@app.exception_handler(exc_class)` - Add an error handler, decorator style. +* `app.add_exception_handler(exc_class_or_status_code, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`. +* `@app.exception_handler(exc_class_or_status_code)` - Add an error handler, decorator style. * `app.debug` - Enable or disable error tracebacks in the browser. diff --git a/docs/debug.md b/docs/debug.md deleted file mode 100644 index 52e194cf9..000000000 --- a/docs/debug.md +++ /dev/null @@ -1,21 +0,0 @@ - -You can use Starlette's `DebugMiddleware` to display simple error tracebacks in the browser. - -```python -from starlette.debug import DebugMiddleware - - -class App: - def __init__(self, scope): - self.scope = scope - - async def __call__(self, receive, send): - raise RuntimeError('Something went wrong') - - -app = DebugMiddleware(App) -``` - -For a more complete handling of exception cases you may wish to use Starlette's -[`ExceptionMiddleware`](../exceptions/) class instead, which also includes -optional debug handling. diff --git a/docs/exceptions.md b/docs/exceptions.md index 09de6e833..03908f86b 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -1,93 +1,76 @@ -Starlette includes an exception handling middleware that you can use in order -to dispatch different classes of exceptions to different handlers. - -To see how this works, we'll start by with this small ASGI application: +Starlette allows you to install custom exception handlers to deal with +how you return responses when errors or handled exceptions occur. ```python -from starlette.exceptions import ExceptionMiddleware, HTTPException - - -class App: - def __init__(self, scope): - raise HTTPException(status_code=403) - - -app = ExceptionMiddleware(App) -``` - -If you run the app and make an HTTP request to it, you'll get a plain text -response with a "403 Permission Denied" response. This is the behaviour that the -default handler responds with when an `HTTPException` class or subclass is raised. - -Let's change the exception handling, so that we get JSON error responses -instead: +from starlette.applications import Starlette +from starlette.responses import HTMLResponse -```python -from starlette.exceptions import ExceptionMiddleware, HTTPException -from starlette.responses import JSONResponse - +HTML_404_PAGE = ... +HTML_500_PAGE = ... -class App: - def __init__(self, scope): - raise HTTPException(status_code=403) +app = Starlette() -def handler(request, exc): - return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) +@app.exception_handler(404) +async def not_found(request, exc): + return HTMLResponse(content=HTML_404_PAGE) -app = ExceptionMiddleware(App) -app.add_exception_handler(HTTPException, handler) +@app.exception_handler(500) +async def server_error(request, exc): + return HTMLResponse(content=HTML_500_PAGE) ``` -Now if we make a request to the application, we'll get back a JSON encoded -HTTP response. - -By default two types of exceptions are caught and dealt with: - -* `HTTPException` - Used to raise standard HTTP error codes. -* `Exception` - Used as a catch-all handler to deal with any `500 Internal -Server Error` responses. The `Exception` case also wraps any other exception -handling. +If `debug` is enabled and an error occurs, then instead of using the installed +500 handler, Starlette will respond with a traceback response. -The catch-all `Exception` case is used to return simple `500 Internal Server Error` -responses. During development you might want to switch the behaviour so that -it displays an error traceback in the browser: - -``` -app = ExceptionMiddleware(App, debug=True) +```python +app = Starlette(debug=True) ``` -This uses the same error tracebacks as the more minimal [`DebugMiddleware`](../debugging). +As well as registering handlers for specific status codes, you can also +register handlers for classes of exceptions. -The exception handler currently only catches and deals with exceptions within -HTTP requests. Any websocket exceptions will simply be raised to the server -and result in an error log. +In particular you might want to override how the built-in `HTTPException` class +is handled. For example, to use JSON style responses: -## ExceptionMiddleware +```python +@app.exception_handler(HTTPException) +async def http_exception(request, exc): + return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) +``` -The exception middleware catches and handles the exceptions, returning -appropriate HTTP responses. +## Errors and handled exceptions -* `ExceptionMiddleware(app, debug=False)` - Instantiate the exception handler, -wrapping up it around an inner ASGI application. +It is important to differentiate between handled exceptions and errors. -Adding handlers: +Handled exceptions do not represent error cases. They are coerced into appropriate +HTTP responses, which are then sent through the standard middleware stack. By default +the `HTTPException` class is used to manage any handled exceptions. -* `.add_exception_handler(exc_class, handler)` - Set a handler function to run -for the given exception class. +Errors are any other exception that occurs within the application. These cases +should bubble through the entire middleware stack as exceptions. Any error +logging middleware should ensure that it re-raises the exception all the +way up to the server. -Enabling debug mode: +In order to deal with this behaviour correctly, the middleware stack of a +`Starlette` application is configured like this: -* `.debug` - If set to `True`, then the catch-all handler for `Exception` will -not be used, and error tracebacks will be sent as responses instead. +* `ServerErrorMiddleware` - Returns 500 responses when server errors occur. +* Installed middleware +* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses. +* Router +* Endpoints ## HTTPException The `HTTPException` class provides a base class that you can use for any -standard HTTP error conditions. The `ExceptionMiddleware` implementation -defaults to returning plain-text HTTP responses for any `HTTPException`. +handled exceptions. The `ExceptionMiddleware` implementation defaults to +returning plain-text HTTP responses for any `HTTPException`. * `HTTPException(status_code, detail=None)` + +You should only raise `HTTPException` inside routing or endpoints. Middleware +classes should instead just return appropriate responses directly. diff --git a/mkdocs.yml b/mkdocs.yml index 91bd5a084..c7cb8fe3e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,7 +24,6 @@ nav: - Events: 'events.md' - Background Tasks: 'background.md' - Exceptions: 'exceptions.md' - - Debug: 'debug.md' - Test Client: 'testclient.md' - Release Notes: 'release-notes.md' diff --git a/starlette/applications.py b/starlette/applications.py index 464e9c572..db47ef62f 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -4,6 +4,7 @@ from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import BaseRoute, Router from starlette.schemas import BaseSchemaGenerator from starlette.types import ASGIApp, ASGIInstance, Scope @@ -11,10 +12,13 @@ class Starlette: def __init__(self, debug: bool = False) -> None: + self._debug = debug self.router = Router() self.lifespan_handler = LifespanHandler() - self.app = self.router self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.error_middleware = ServerErrorMiddleware( + self.exception_middleware, debug=debug + ) self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] @property @@ -23,11 +27,13 @@ def routes(self) -> typing.List[BaseRoute]: @property def debug(self) -> bool: - return self.exception_middleware.debug + return self._debug @debug.setter def debug(self, value: bool) -> None: + self._debug = value self.exception_middleware.debug = value + self.error_middleware.debug = value @property def schema(self) -> dict: @@ -41,10 +47,21 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: - self.exception_middleware.app = middleware_class(self.app, **kwargs) + self.error_middleware.app = middleware_class( + self.error_middleware.app, **kwargs + ) - def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None: - self.exception_middleware.add_exception_handler(exc_class, handler) + def add_exception_handler( + self, + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable, + ) -> None: + if exc_class_or_status_code in (500, Exception): + self.error_middleware.handler = handler + else: + self.exception_middleware.add_exception_handler( + exc_class_or_status_code, handler + ) def add_event_handler(self, event_type: str, func: typing.Callable) -> None: self.lifespan_handler.add_event_handler(event_type, func) @@ -61,9 +78,11 @@ def add_route( def add_websocket_route(self, path: str, route: typing.Callable) -> None: self.router.add_websocket_route(path, route) - def exception_handler(self, exc_class: type) -> typing.Callable: + def exception_handler( + self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]] + ) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: - self.add_exception_handler(exc_class, func) + self.add_exception_handler(exc_class_or_status_code, func) return func return decorator @@ -107,4 +126,4 @@ def __call__(self, scope: Scope) -> ASGIInstance: scope["app"] = self if scope["type"] == "lifespan": return self.lifespan_handler(scope) - return self.exception_middleware(scope) + return self.error_middleware(scope) diff --git a/starlette/debug.py b/starlette/debug.py deleted file mode 100644 index 32cac008d..000000000 --- a/starlette/debug.py +++ /dev/null @@ -1,108 +0,0 @@ -import traceback - -from starlette.requests import Request -from starlette.responses import HTMLResponse, PlainTextResponse, Response -from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send - -STYLES = """\ - .traceback-container {border: 1px solid #038BB8;} - .traceback-title {background-color: #038BB8;color: lemonchiffon;padding: 12px;font-size: 20px;margin-top: 0px;} - .traceback-content {padding: 5px 0px 20px 20px;} - .frame-line {font-weight: unset;padding: 10px 10px 10px 20px;background-color: #E4F4FD; - margin-left: 10px;margin-right: 10px;font: #394D54;color: #191f21;font-size: 17px;border: 1px solid #c7dce8;} -""" - -TEMPLATE = """ - - Starlette Debugger -

500 Server Error

-

{error}

-
-

Traceback

-
{ext_html}
-
-""" - -FRAME_TEMPLATE = """ -

- File `{frame_filename}`, - line {frame_lineno}, - in {frame_name} -

{frame_line}

-

-""" - - -class DebugGenerator: - def __init__(self, exc: Exception) -> None: - self.exc = exc - self.traceback_obj = traceback.TracebackException.from_exception( - exc, capture_locals=True - ) - self.error = f"{self.traceback_obj.exc_type.__name__}: {self.traceback_obj}" - - @staticmethod - def gen_frame_html(frame: traceback.FrameSummary) -> str: - values = { - "frame_filename": frame.filename, - "frame_lineno": frame.lineno, - "frame_name": frame.name, - "frame_line": frame.line, - } - return FRAME_TEMPLATE.format(**values) - - def generate_html(self) -> str: - ext_html = "".join( - [self.gen_frame_html(frame) for frame in self.traceback_obj.stack] - ) - values = {"style": STYLES, "error": self.error, "ext_html": ext_html} - - return TEMPLATE.format(**values) - - def generate_plain_text(self) -> str: - return "".join(traceback.format_tb(self.exc.__traceback__)) - - -def get_debug_response(request: Request, exc: Exception) -> Response: - accept = request.headers.get("accept", "") - debug_gen = DebugGenerator(exc) - - if "text/html" in accept: - content = debug_gen.generate_html() - return HTMLResponse(content, status_code=500) - content = debug_gen.generate_plain_text() - return PlainTextResponse(content, status_code=500) - - -class DebugMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def __call__(self, scope: Scope) -> ASGIInstance: - if scope["type"] != "http": - return self.app(scope) - return _DebugResponder(self.app, scope) - - -class _DebugResponder: - def __init__(self, app: ASGIApp, scope: Scope) -> None: - self.app = app - self.scope = scope - self.response_started = False - - async def __call__(self, receive: Receive, send: Send) -> None: - self.raw_send = send - try: - asgi = self.app(self.scope) - await asgi(receive, self.send) - except Exception as exc: - if not self.response_started: - request = Request(self.scope) - response = get_debug_response(request, exc) - await response(receive, send) - raise exc from None - - async def send(self, message: Message) -> None: - if message["type"] == "http.response.start": - self.response_started = True - await self.raw_send(message) diff --git a/starlette/exceptions.py b/starlette/exceptions.py index a8d7142b7..64691e8cf 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -2,7 +2,7 @@ import http import typing -from starlette.debug import get_debug_response +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send @@ -19,24 +19,27 @@ def __init__(self, status_code: int, detail: str = None) -> None: class ExceptionMiddleware: def __init__(self, app: ASGIApp, debug: bool = False) -> None: self.app = app - self.debug = debug + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers = {} # type: typing.Dict[int, typing.Callable] self._exception_handlers = { - Exception: self.server_error, - HTTPException: self.http_exception, - } + HTTPException: self.http_exception + } # type: typing.Dict[typing.Type[Exception], typing.Callable] def add_exception_handler( - self, exc_class: typing.Type[Exception], handler: typing.Callable + self, + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable, ) -> None: - assert issubclass(exc_class, BaseException) - self._exception_handlers[exc_class] = handler + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler def _lookup_exception_handler( self, exc: Exception ) -> typing.Optional[typing.Callable]: for cls in type(exc).__mro__: - if cls is Exception: - break if cls in self._exception_handlers: return self._exception_handlers[cls] return None @@ -56,48 +59,30 @@ async def sender(message: Message) -> None: await send(message) try: - try: - instance = self.app(scope) - await instance(receive, sender) - except Exception as exc: - # Exception handling is applied to any registed exception - # class or subclass that occurs within the application. - handler = self._lookup_exception_handler(exc) + instance = self.app(scope) + await instance(receive, sender) + except Exception as exc: + handler = None - # Note that we always handle `Exception` in the outermost block. - if handler is None: - raise exc from None + if isinstance(exc, HTTPException): + handler = self._status_handlers.get(exc.status_code) - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + if handler is None: + handler = self._lookup_exception_handler(exc) - request = Request(scope, receive=receive) - if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) - else: - response = handler(request, exc) - await response(receive, sender) + if handler is None: + raise exc from None + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc - except Exception as exc: - # The 'Exception' case always wraps everything else, and - # provides a last-ditch handler for dealing with server errors. request = Request(scope, receive=receive) - if self.debug: - handler = get_debug_response - else: - handler = self._exception_handlers[Exception] if asyncio.iscoroutinefunction(handler): - response = await handler(request, exc) # type: ignore + response = await handler(request, exc) else: - response = handler(request, exc) # type: ignore - if not response_started: - await response(receive, send) - - # We always raise the exception up to the server so that it - # is notified too. Typically this will mean that it'll log - # the exception. - raise + response = await run_in_threadpool(handler, request, exc) + await response(receive, sender) return app @@ -105,6 +90,3 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: if exc.status_code in {204, 304}: return Response(b"", status_code=exc.status_code) return PlainTextResponse(exc.detail, status_code=exc.status_code) - - def server_error(self, request: Request, exc: HTTPException) -> Response: - return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py new file mode 100644 index 000000000..56cb20f47 --- /dev/null +++ b/starlette/middleware/errors.py @@ -0,0 +1,141 @@ +import asyncio +import functools +import traceback +import typing + +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request +from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send + +STYLES = """\ + .traceback-container {border: 1px solid #038BB8;} + .traceback-title {background-color: #038BB8;color: lemonchiffon;padding: 12px;font-size: 20px;margin-top: 0px;} + .traceback-content {padding: 5px 0px 20px 20px;} + .frame-line {font-weight: unset;padding: 10px 10px 10px 20px;background-color: #E4F4FD; + margin-left: 10px;margin-right: 10px;font: #394D54;color: #191f21;font-size: 17px;border: 1px solid #c7dce8;} +""" + +TEMPLATE = """ + + Starlette Debugger +

500 Server Error

+

{error}

+
+

Traceback

+
{ext_html}
+
+""" + +FRAME_TEMPLATE = """ +

+ File `{frame_filename}`, + line {frame_lineno}, + in {frame_name} +

{frame_line}

+

+""" + + +class DebugGenerator: + def __init__(self, exc: Exception) -> None: + self.exc = exc + self.traceback_obj = traceback.TracebackException.from_exception( + exc, capture_locals=True + ) + self.error = f"{self.traceback_obj.exc_type.__name__}: {self.traceback_obj}" + + @staticmethod + def gen_frame_html(frame: traceback.FrameSummary) -> str: + values = { + "frame_filename": frame.filename, + "frame_lineno": frame.lineno, + "frame_name": frame.name, + "frame_line": frame.line, + } + return FRAME_TEMPLATE.format(**values) + + def generate_html(self) -> str: + ext_html = "".join( + [self.gen_frame_html(frame) for frame in self.traceback_obj.stack] + ) + values = {"style": STYLES, "error": self.error, "ext_html": ext_html} + + return TEMPLATE.format(**values) + + def generate_plain_text(self) -> str: + return "".join(traceback.format_tb(self.exc.__traceback__)) + + +class ServerErrorMiddleware: + """ + Handles returning 500 responses when a server error occurs. + + If 'debug' is set, then traceback responses will be returned, + otherwise the designated 'handler' will be called. + + This middleware class should generally be used to wrap *everything* + else up, so that unhandled exceptions anywhere in the stack + always result in an appropriate 500 response. + """ + + def __init__( + self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False + ) -> None: + self.app = app + self.handler = handler + self.debug = debug + + def __call__(self, scope: Scope) -> ASGIInstance: + if scope["type"] != "http": + return self.app(scope) + return functools.partial(self.asgi, scope=scope) + + async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None: + response_started = False + + async def _send(message: Message) -> None: + nonlocal response_started, send + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + asgi = self.app(scope) + await asgi(receive, _send) + except Exception as exc: + if not response_started: + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if asyncio.iscoroutinefunction(self.handler): + response = await self.handler(request, exc) + else: + response = await run_in_threadpool(self.handler, request, exc) + + await response(receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or allows test clients + # to optionally raise the error within the test case. + raise exc from None + + def debug_response(self, request: Request, exc: Exception) -> Response: + accept = request.headers.get("accept", "") + debug_gen = DebugGenerator(exc) + + if "text/html" in accept: + content = debug_gen.generate_html() + return HTMLResponse(content, status_code=500) + content = debug_gen.generate_plain_text() + return PlainTextResponse(content, status_code=500) + + def error_response(self, request: Request, exc: Exception) -> Response: + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/tests/test_debug.py b/tests/middleware/test_errors.py similarity index 69% rename from tests/test_debug.py rename to tests/middleware/test_errors.py index 48f71162e..ccaaafb67 100644 --- a/tests/test_debug.py +++ b/tests/middleware/test_errors.py @@ -1,10 +1,27 @@ import pytest -from starlette.debug import DebugMiddleware -from starlette.responses import Response +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.responses import JSONResponse, Response from starlette.testclient import TestClient +def test_handler(): + def app(scope): + async def asgi(receive, send): + raise RuntimeError("Something went wrong") + + return asgi + + def error_500(request, exc): + return JSONResponse({"detail": "Server Error"}, status_code=500) + + app = ServerErrorMiddleware(app, handler=error_500) + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") + assert response.status_code == 500 + assert response.json() == {"detail": "Server Error"} + + def test_debug_text(): def app(scope): async def asgi(receive, send): @@ -12,7 +29,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 @@ -27,7 +44,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 @@ -44,7 +61,7 @@ async def asgi(receive, send): return asgi - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app) with pytest.raises(RuntimeError): client.get("/") @@ -54,7 +71,7 @@ def test_debug_error_during_scope(): def app(scope): raise RuntimeError("Something went wrong") - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app, debug=True) client = TestClient(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 @@ -70,7 +87,7 @@ def test_debug_not_http(): def app(scope): raise RuntimeError("Something went wrong") - app = DebugMiddleware(app) + app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): app({"type": "websocket"}) diff --git a/tests/test_applications.py b/tests/test_applications.py index 323914a07..ef8496c8b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -29,13 +29,18 @@ def __call__(self, scope): app.add_middleware(TrustedHostMiddleware, hostname="testserver") -@app.exception_handler(Exception) +@app.exception_handler(500) async def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) +@app.exception_handler(405) +async def method_not_allowed(request, exc): + return JSONResponse({"detail": "Custom message"}, status_code=405) + + @app.exception_handler(HTTPException) -async def handler(request, exc): +async def http_exception(request, exc): return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) @@ -136,11 +141,11 @@ def test_400(): def test_405(): response = client.post("/func") assert response.status_code == 405 - assert response.json() == {"detail": "Method Not Allowed"} + assert response.json() == {"detail": "Custom message"} response = client.post("/class") assert response.status_code == 405 - assert response.json() == {"detail": "Method Not Allowed"} + assert response.json() == {"detail": "Custom message"} def test_500(): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 51de7ff1b..a18ef3246 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -43,23 +43,23 @@ async def __call__(self, receive, send): client = TestClient(app) -def test_server_error(): - with pytest.raises(RuntimeError): - response = client.get("/runtime_error") - - allow_500_client = TestClient(app, raise_server_exceptions=False) - response = allow_500_client.get("/runtime_error") - assert response.status_code == 500 - assert response.text == "Internal Server Error" - - -def test_debug_enabled(): - app = ExceptionMiddleware(router) - app.debug = True - allow_500_client = TestClient(app, raise_server_exceptions=False) - response = allow_500_client.get("/runtime_error") - assert response.status_code == 500 - assert "RuntimeError" in response.text +# def test_server_error(): +# with pytest.raises(RuntimeError): +# response = client.get("/runtime_error") +# +# allow_500_client = TestClient(app, raise_server_exceptions=False) +# response = allow_500_client.get("/runtime_error") +# assert response.status_code == 500 +# assert response.text == "Internal Server Error" + + +# def test_debug_enabled(): +# app = ExceptionMiddleware(router) +# app.debug = True +# allow_500_client = TestClient(app, raise_server_exceptions=False) +# response = allow_500_client.get("/runtime_error") +# assert response.status_code == 500 +# assert "RuntimeError" in response.text def test_not_acceptable():