From 9f3dcb7d2b8556773e0d0bf7fd61311dded2ea35 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 8 Nov 2018 11:59:15 +0000
Subject: [PATCH] Rejig ExceptionMiddleware and ServerErrorMiddleware (#193)
* Rejig ExceptionMiddleware and ServerErrorMiddleware
* Tweak DebugMiddleware implementation
* Support custom 500 handlers
* Exception handling updates
---
docs/applications.md | 4 +-
docs/debug.md | 21 ---
docs/exceptions.md | 111 ++++++--------
mkdocs.yml | 1 -
starlette/applications.py | 35 ++++-
starlette/debug.py | 108 --------------
starlette/exceptions.py | 78 ++++------
starlette/middleware/errors.py | 141 ++++++++++++++++++
.../test_errors.py} | 31 +++-
tests/test_applications.py | 13 +-
tests/test_exceptions.py | 34 ++---
11 files changed, 297 insertions(+), 280 deletions(-)
delete mode 100644 docs/debug.md
delete mode 100644 starlette/debug.py
create mode 100644 starlette/middleware/errors.py
rename tests/{test_debug.py => middleware/test_errors.py} (69%)
diff --git a/docs/applications.md b/docs/applications.md
index a9f3aed58..aac8592e5 100644
--- a/docs/applications.md
+++ b/docs/applications.md
@@ -69,6 +69,6 @@ Submounting applications is a powerful way to include reusable ASGI applications
You can use either of the following to catch and handle particular types of
exceptions that occur within the application:
-* `app.add_exception_handler(exc_class, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`.
-* `@app.exception_handler(exc_class)` - Add an error handler, decorator style.
+* `app.add_exception_handler(exc_class_or_status_code, handler)` - Add an error handler. The handler function may be either a coroutine or a regular function, with a signature like `func(request, exc) -> response`.
+* `@app.exception_handler(exc_class_or_status_code)` - Add an error handler, decorator style.
* `app.debug` - Enable or disable error tracebacks in the browser.
diff --git a/docs/debug.md b/docs/debug.md
deleted file mode 100644
index 52e194cf9..000000000
--- a/docs/debug.md
+++ /dev/null
@@ -1,21 +0,0 @@
-
-You can use Starlette's `DebugMiddleware` to display simple error tracebacks in the browser.
-
-```python
-from starlette.debug import DebugMiddleware
-
-
-class App:
- def __init__(self, scope):
- self.scope = scope
-
- async def __call__(self, receive, send):
- raise RuntimeError('Something went wrong')
-
-
-app = DebugMiddleware(App)
-```
-
-For a more complete handling of exception cases you may wish to use Starlette's
-[`ExceptionMiddleware`](../exceptions/) class instead, which also includes
-optional debug handling.
diff --git a/docs/exceptions.md b/docs/exceptions.md
index 09de6e833..03908f86b 100644
--- a/docs/exceptions.md
+++ b/docs/exceptions.md
@@ -1,93 +1,76 @@
-Starlette includes an exception handling middleware that you can use in order
-to dispatch different classes of exceptions to different handlers.
-
-To see how this works, we'll start by with this small ASGI application:
+Starlette allows you to install custom exception handlers to deal with
+how you return responses when errors or handled exceptions occur.
```python
-from starlette.exceptions import ExceptionMiddleware, HTTPException
-
-
-class App:
- def __init__(self, scope):
- raise HTTPException(status_code=403)
-
-
-app = ExceptionMiddleware(App)
-```
-
-If you run the app and make an HTTP request to it, you'll get a plain text
-response with a "403 Permission Denied" response. This is the behaviour that the
-default handler responds with when an `HTTPException` class or subclass is raised.
-
-Let's change the exception handling, so that we get JSON error responses
-instead:
+from starlette.applications import Starlette
+from starlette.responses import HTMLResponse
-```python
-from starlette.exceptions import ExceptionMiddleware, HTTPException
-from starlette.responses import JSONResponse
-
+HTML_404_PAGE = ...
+HTML_500_PAGE = ...
-class App:
- def __init__(self, scope):
- raise HTTPException(status_code=403)
+app = Starlette()
-def handler(request, exc):
- return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
+@app.exception_handler(404)
+async def not_found(request, exc):
+ return HTMLResponse(content=HTML_404_PAGE)
-app = ExceptionMiddleware(App)
-app.add_exception_handler(HTTPException, handler)
+@app.exception_handler(500)
+async def server_error(request, exc):
+ return HTMLResponse(content=HTML_500_PAGE)
```
-Now if we make a request to the application, we'll get back a JSON encoded
-HTTP response.
-
-By default two types of exceptions are caught and dealt with:
-
-* `HTTPException` - Used to raise standard HTTP error codes.
-* `Exception` - Used as a catch-all handler to deal with any `500 Internal
-Server Error` responses. The `Exception` case also wraps any other exception
-handling.
+If `debug` is enabled and an error occurs, then instead of using the installed
+500 handler, Starlette will respond with a traceback response.
-The catch-all `Exception` case is used to return simple `500 Internal Server Error`
-responses. During development you might want to switch the behaviour so that
-it displays an error traceback in the browser:
-
-```
-app = ExceptionMiddleware(App, debug=True)
+```python
+app = Starlette(debug=True)
```
-This uses the same error tracebacks as the more minimal [`DebugMiddleware`](../debugging).
+As well as registering handlers for specific status codes, you can also
+register handlers for classes of exceptions.
-The exception handler currently only catches and deals with exceptions within
-HTTP requests. Any websocket exceptions will simply be raised to the server
-and result in an error log.
+In particular you might want to override how the built-in `HTTPException` class
+is handled. For example, to use JSON style responses:
-## ExceptionMiddleware
+```python
+@app.exception_handler(HTTPException)
+async def http_exception(request, exc):
+ return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
+```
-The exception middleware catches and handles the exceptions, returning
-appropriate HTTP responses.
+## Errors and handled exceptions
-* `ExceptionMiddleware(app, debug=False)` - Instantiate the exception handler,
-wrapping up it around an inner ASGI application.
+It is important to differentiate between handled exceptions and errors.
-Adding handlers:
+Handled exceptions do not represent error cases. They are coerced into appropriate
+HTTP responses, which are then sent through the standard middleware stack. By default
+the `HTTPException` class is used to manage any handled exceptions.
-* `.add_exception_handler(exc_class, handler)` - Set a handler function to run
-for the given exception class.
+Errors are any other exception that occurs within the application. These cases
+should bubble through the entire middleware stack as exceptions. Any error
+logging middleware should ensure that it re-raises the exception all the
+way up to the server.
-Enabling debug mode:
+In order to deal with this behaviour correctly, the middleware stack of a
+`Starlette` application is configured like this:
-* `.debug` - If set to `True`, then the catch-all handler for `Exception` will
-not be used, and error tracebacks will be sent as responses instead.
+* `ServerErrorMiddleware` - Returns 500 responses when server errors occur.
+* Installed middleware
+* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses.
+* Router
+* Endpoints
## HTTPException
The `HTTPException` class provides a base class that you can use for any
-standard HTTP error conditions. The `ExceptionMiddleware` implementation
-defaults to returning plain-text HTTP responses for any `HTTPException`.
+handled exceptions. The `ExceptionMiddleware` implementation defaults to
+returning plain-text HTTP responses for any `HTTPException`.
* `HTTPException(status_code, detail=None)`
+
+You should only raise `HTTPException` inside routing or endpoints. Middleware
+classes should instead just return appropriate responses directly.
diff --git a/mkdocs.yml b/mkdocs.yml
index 91bd5a084..c7cb8fe3e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -24,7 +24,6 @@ nav:
- Events: 'events.md'
- Background Tasks: 'background.md'
- Exceptions: 'exceptions.md'
- - Debug: 'debug.md'
- Test Client: 'testclient.md'
- Release Notes: 'release-notes.md'
diff --git a/starlette/applications.py b/starlette/applications.py
index 464e9c572..db47ef62f 100644
--- a/starlette/applications.py
+++ b/starlette/applications.py
@@ -4,6 +4,7 @@
from starlette.exceptions import ExceptionMiddleware
from starlette.lifespan import LifespanHandler
from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.middleware.errors import ServerErrorMiddleware
from starlette.routing import BaseRoute, Router
from starlette.schemas import BaseSchemaGenerator
from starlette.types import ASGIApp, ASGIInstance, Scope
@@ -11,10 +12,13 @@
class Starlette:
def __init__(self, debug: bool = False) -> None:
+ self._debug = debug
self.router = Router()
self.lifespan_handler = LifespanHandler()
- self.app = self.router
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
+ self.error_middleware = ServerErrorMiddleware(
+ self.exception_middleware, debug=debug
+ )
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
@property
@@ -23,11 +27,13 @@ def routes(self) -> typing.List[BaseRoute]:
@property
def debug(self) -> bool:
- return self.exception_middleware.debug
+ return self._debug
@debug.setter
def debug(self, value: bool) -> None:
+ self._debug = value
self.exception_middleware.debug = value
+ self.error_middleware.debug = value
@property
def schema(self) -> dict:
@@ -41,10 +47,21 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.router.mount(path, app=app, name=name)
def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
- self.exception_middleware.app = middleware_class(self.app, **kwargs)
+ self.error_middleware.app = middleware_class(
+ self.error_middleware.app, **kwargs
+ )
- def add_exception_handler(self, exc_class: type, handler: typing.Callable) -> None:
- self.exception_middleware.add_exception_handler(exc_class, handler)
+ def add_exception_handler(
+ self,
+ exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
+ handler: typing.Callable,
+ ) -> None:
+ if exc_class_or_status_code in (500, Exception):
+ self.error_middleware.handler = handler
+ else:
+ self.exception_middleware.add_exception_handler(
+ exc_class_or_status_code, handler
+ )
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
self.lifespan_handler.add_event_handler(event_type, func)
@@ -61,9 +78,11 @@ def add_route(
def add_websocket_route(self, path: str, route: typing.Callable) -> None:
self.router.add_websocket_route(path, route)
- def exception_handler(self, exc_class: type) -> typing.Callable:
+ def exception_handler(
+ self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
+ ) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
- self.add_exception_handler(exc_class, func)
+ self.add_exception_handler(exc_class_or_status_code, func)
return func
return decorator
@@ -107,4 +126,4 @@ def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
if scope["type"] == "lifespan":
return self.lifespan_handler(scope)
- return self.exception_middleware(scope)
+ return self.error_middleware(scope)
diff --git a/starlette/debug.py b/starlette/debug.py
deleted file mode 100644
index 32cac008d..000000000
--- a/starlette/debug.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import traceback
-
-from starlette.requests import Request
-from starlette.responses import HTMLResponse, PlainTextResponse, Response
-from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send
-
-STYLES = """\
- .traceback-container {border: 1px solid #038BB8;}
- .traceback-title {background-color: #038BB8;color: lemonchiffon;padding: 12px;font-size: 20px;margin-top: 0px;}
- .traceback-content {padding: 5px 0px 20px 20px;}
- .frame-line {font-weight: unset;padding: 10px 10px 10px 20px;background-color: #E4F4FD;
- margin-left: 10px;margin-right: 10px;font: #394D54;color: #191f21;font-size: 17px;border: 1px solid #c7dce8;}
-"""
-
-TEMPLATE = """
-
- Starlette Debugger
- 500 Server Error
- {error}
-
-
Traceback
-
{ext_html}
-
-"""
-
-FRAME_TEMPLATE = """
-
- File `{frame_filename}`,
- line {frame_lineno},
- in {frame_name}
-
{frame_line}
-
-"""
-
-
-class DebugGenerator:
- def __init__(self, exc: Exception) -> None:
- self.exc = exc
- self.traceback_obj = traceback.TracebackException.from_exception(
- exc, capture_locals=True
- )
- self.error = f"{self.traceback_obj.exc_type.__name__}: {self.traceback_obj}"
-
- @staticmethod
- def gen_frame_html(frame: traceback.FrameSummary) -> str:
- values = {
- "frame_filename": frame.filename,
- "frame_lineno": frame.lineno,
- "frame_name": frame.name,
- "frame_line": frame.line,
- }
- return FRAME_TEMPLATE.format(**values)
-
- def generate_html(self) -> str:
- ext_html = "".join(
- [self.gen_frame_html(frame) for frame in self.traceback_obj.stack]
- )
- values = {"style": STYLES, "error": self.error, "ext_html": ext_html}
-
- return TEMPLATE.format(**values)
-
- def generate_plain_text(self) -> str:
- return "".join(traceback.format_tb(self.exc.__traceback__))
-
-
-def get_debug_response(request: Request, exc: Exception) -> Response:
- accept = request.headers.get("accept", "")
- debug_gen = DebugGenerator(exc)
-
- if "text/html" in accept:
- content = debug_gen.generate_html()
- return HTMLResponse(content, status_code=500)
- content = debug_gen.generate_plain_text()
- return PlainTextResponse(content, status_code=500)
-
-
-class DebugMiddleware:
- def __init__(self, app: ASGIApp) -> None:
- self.app = app
-
- def __call__(self, scope: Scope) -> ASGIInstance:
- if scope["type"] != "http":
- return self.app(scope)
- return _DebugResponder(self.app, scope)
-
-
-class _DebugResponder:
- def __init__(self, app: ASGIApp, scope: Scope) -> None:
- self.app = app
- self.scope = scope
- self.response_started = False
-
- async def __call__(self, receive: Receive, send: Send) -> None:
- self.raw_send = send
- try:
- asgi = self.app(self.scope)
- await asgi(receive, self.send)
- except Exception as exc:
- if not self.response_started:
- request = Request(self.scope)
- response = get_debug_response(request, exc)
- await response(receive, send)
- raise exc from None
-
- async def send(self, message: Message) -> None:
- if message["type"] == "http.response.start":
- self.response_started = True
- await self.raw_send(message)
diff --git a/starlette/exceptions.py b/starlette/exceptions.py
index a8d7142b7..64691e8cf 100644
--- a/starlette/exceptions.py
+++ b/starlette/exceptions.py
@@ -2,7 +2,7 @@
import http
import typing
-from starlette.debug import get_debug_response
+from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send
@@ -19,24 +19,27 @@ def __init__(self, status_code: int, detail: str = None) -> None:
class ExceptionMiddleware:
def __init__(self, app: ASGIApp, debug: bool = False) -> None:
self.app = app
- self.debug = debug
+ self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
+ self._status_handlers = {} # type: typing.Dict[int, typing.Callable]
self._exception_handlers = {
- Exception: self.server_error,
- HTTPException: self.http_exception,
- }
+ HTTPException: self.http_exception
+ } # type: typing.Dict[typing.Type[Exception], typing.Callable]
def add_exception_handler(
- self, exc_class: typing.Type[Exception], handler: typing.Callable
+ self,
+ exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
+ handler: typing.Callable,
) -> None:
- assert issubclass(exc_class, BaseException)
- self._exception_handlers[exc_class] = handler
+ if isinstance(exc_class_or_status_code, int):
+ self._status_handlers[exc_class_or_status_code] = handler
+ else:
+ assert issubclass(exc_class_or_status_code, Exception)
+ self._exception_handlers[exc_class_or_status_code] = handler
def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
- if cls is Exception:
- break
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None
@@ -56,48 +59,30 @@ async def sender(message: Message) -> None:
await send(message)
try:
- try:
- instance = self.app(scope)
- await instance(receive, sender)
- except Exception as exc:
- # Exception handling is applied to any registed exception
- # class or subclass that occurs within the application.
- handler = self._lookup_exception_handler(exc)
+ instance = self.app(scope)
+ await instance(receive, sender)
+ except Exception as exc:
+ handler = None
- # Note that we always handle `Exception` in the outermost block.
- if handler is None:
- raise exc from None
+ if isinstance(exc, HTTPException):
+ handler = self._status_handlers.get(exc.status_code)
- if response_started:
- msg = "Caught handled exception, but response already started."
- raise RuntimeError(msg) from exc
+ if handler is None:
+ handler = self._lookup_exception_handler(exc)
- request = Request(scope, receive=receive)
- if asyncio.iscoroutinefunction(handler):
- response = await handler(request, exc)
- else:
- response = handler(request, exc)
- await response(receive, sender)
+ if handler is None:
+ raise exc from None
+
+ if response_started:
+ msg = "Caught handled exception, but response already started."
+ raise RuntimeError(msg) from exc
- except Exception as exc:
- # The 'Exception' case always wraps everything else, and
- # provides a last-ditch handler for dealing with server errors.
request = Request(scope, receive=receive)
- if self.debug:
- handler = get_debug_response
- else:
- handler = self._exception_handlers[Exception]
if asyncio.iscoroutinefunction(handler):
- response = await handler(request, exc) # type: ignore
+ response = await handler(request, exc)
else:
- response = handler(request, exc) # type: ignore
- if not response_started:
- await response(receive, send)
-
- # We always raise the exception up to the server so that it
- # is notified too. Typically this will mean that it'll log
- # the exception.
- raise
+ response = await run_in_threadpool(handler, request, exc)
+ await response(receive, sender)
return app
@@ -105,6 +90,3 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(b"", status_code=exc.status_code)
return PlainTextResponse(exc.detail, status_code=exc.status_code)
-
- def server_error(self, request: Request, exc: HTTPException) -> Response:
- return PlainTextResponse("Internal Server Error", status_code=500)
diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py
new file mode 100644
index 000000000..56cb20f47
--- /dev/null
+++ b/starlette/middleware/errors.py
@@ -0,0 +1,141 @@
+import asyncio
+import functools
+import traceback
+import typing
+
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+from starlette.responses import HTMLResponse, PlainTextResponse, Response
+from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send
+
+STYLES = """\
+ .traceback-container {border: 1px solid #038BB8;}
+ .traceback-title {background-color: #038BB8;color: lemonchiffon;padding: 12px;font-size: 20px;margin-top: 0px;}
+ .traceback-content {padding: 5px 0px 20px 20px;}
+ .frame-line {font-weight: unset;padding: 10px 10px 10px 20px;background-color: #E4F4FD;
+ margin-left: 10px;margin-right: 10px;font: #394D54;color: #191f21;font-size: 17px;border: 1px solid #c7dce8;}
+"""
+
+TEMPLATE = """
+
+ Starlette Debugger
+ 500 Server Error
+ {error}
+
+
Traceback
+
{ext_html}
+
+"""
+
+FRAME_TEMPLATE = """
+
+ File `{frame_filename}`,
+ line {frame_lineno},
+ in {frame_name}
+
{frame_line}
+
+"""
+
+
+class DebugGenerator:
+ def __init__(self, exc: Exception) -> None:
+ self.exc = exc
+ self.traceback_obj = traceback.TracebackException.from_exception(
+ exc, capture_locals=True
+ )
+ self.error = f"{self.traceback_obj.exc_type.__name__}: {self.traceback_obj}"
+
+ @staticmethod
+ def gen_frame_html(frame: traceback.FrameSummary) -> str:
+ values = {
+ "frame_filename": frame.filename,
+ "frame_lineno": frame.lineno,
+ "frame_name": frame.name,
+ "frame_line": frame.line,
+ }
+ return FRAME_TEMPLATE.format(**values)
+
+ def generate_html(self) -> str:
+ ext_html = "".join(
+ [self.gen_frame_html(frame) for frame in self.traceback_obj.stack]
+ )
+ values = {"style": STYLES, "error": self.error, "ext_html": ext_html}
+
+ return TEMPLATE.format(**values)
+
+ def generate_plain_text(self) -> str:
+ return "".join(traceback.format_tb(self.exc.__traceback__))
+
+
+class ServerErrorMiddleware:
+ """
+ Handles returning 500 responses when a server error occurs.
+
+ If 'debug' is set, then traceback responses will be returned,
+ otherwise the designated 'handler' will be called.
+
+ This middleware class should generally be used to wrap *everything*
+ else up, so that unhandled exceptions anywhere in the stack
+ always result in an appropriate 500 response.
+ """
+
+ def __init__(
+ self, app: ASGIApp, handler: typing.Callable = None, debug: bool = False
+ ) -> None:
+ self.app = app
+ self.handler = handler
+ self.debug = debug
+
+ def __call__(self, scope: Scope) -> ASGIInstance:
+ if scope["type"] != "http":
+ return self.app(scope)
+ return functools.partial(self.asgi, scope=scope)
+
+ async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
+ response_started = False
+
+ async def _send(message: Message) -> None:
+ nonlocal response_started, send
+
+ if message["type"] == "http.response.start":
+ response_started = True
+ await send(message)
+
+ try:
+ asgi = self.app(scope)
+ await asgi(receive, _send)
+ except Exception as exc:
+ if not response_started:
+ request = Request(scope)
+ if self.debug:
+ # In debug mode, return traceback responses.
+ response = self.debug_response(request, exc)
+ elif self.handler is None:
+ # Use our default 500 error handler.
+ response = self.error_response(request, exc)
+ else:
+ # Use an installed 500 error handler.
+ if asyncio.iscoroutinefunction(self.handler):
+ response = await self.handler(request, exc)
+ else:
+ response = await run_in_threadpool(self.handler, request, exc)
+
+ await response(receive, send)
+
+ # We always continue to raise the exception.
+ # This allows servers to log the error, or allows test clients
+ # to optionally raise the error within the test case.
+ raise exc from None
+
+ def debug_response(self, request: Request, exc: Exception) -> Response:
+ accept = request.headers.get("accept", "")
+ debug_gen = DebugGenerator(exc)
+
+ if "text/html" in accept:
+ content = debug_gen.generate_html()
+ return HTMLResponse(content, status_code=500)
+ content = debug_gen.generate_plain_text()
+ return PlainTextResponse(content, status_code=500)
+
+ def error_response(self, request: Request, exc: Exception) -> Response:
+ return PlainTextResponse("Internal Server Error", status_code=500)
diff --git a/tests/test_debug.py b/tests/middleware/test_errors.py
similarity index 69%
rename from tests/test_debug.py
rename to tests/middleware/test_errors.py
index 48f71162e..ccaaafb67 100644
--- a/tests/test_debug.py
+++ b/tests/middleware/test_errors.py
@@ -1,10 +1,27 @@
import pytest
-from starlette.debug import DebugMiddleware
-from starlette.responses import Response
+from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.responses import JSONResponse, Response
from starlette.testclient import TestClient
+def test_handler():
+ def app(scope):
+ async def asgi(receive, send):
+ raise RuntimeError("Something went wrong")
+
+ return asgi
+
+ def error_500(request, exc):
+ return JSONResponse({"detail": "Server Error"}, status_code=500)
+
+ app = ServerErrorMiddleware(app, handler=error_500)
+ client = TestClient(app, raise_server_exceptions=False)
+ response = client.get("/")
+ assert response.status_code == 500
+ assert response.json() == {"detail": "Server Error"}
+
+
def test_debug_text():
def app(scope):
async def asgi(receive, send):
@@ -12,7 +29,7 @@ async def asgi(receive, send):
return asgi
- app = DebugMiddleware(app)
+ app = ServerErrorMiddleware(app, debug=True)
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
@@ -27,7 +44,7 @@ async def asgi(receive, send):
return asgi
- app = DebugMiddleware(app)
+ app = ServerErrorMiddleware(app, debug=True)
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/", headers={"Accept": "text/html, */*"})
assert response.status_code == 500
@@ -44,7 +61,7 @@ async def asgi(receive, send):
return asgi
- app = DebugMiddleware(app)
+ app = ServerErrorMiddleware(app, debug=True)
client = TestClient(app)
with pytest.raises(RuntimeError):
client.get("/")
@@ -54,7 +71,7 @@ def test_debug_error_during_scope():
def app(scope):
raise RuntimeError("Something went wrong")
- app = DebugMiddleware(app)
+ app = ServerErrorMiddleware(app, debug=True)
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/", headers={"Accept": "text/html, */*"})
assert response.status_code == 500
@@ -70,7 +87,7 @@ def test_debug_not_http():
def app(scope):
raise RuntimeError("Something went wrong")
- app = DebugMiddleware(app)
+ app = ServerErrorMiddleware(app)
with pytest.raises(RuntimeError):
app({"type": "websocket"})
diff --git a/tests/test_applications.py b/tests/test_applications.py
index 323914a07..ef8496c8b 100644
--- a/tests/test_applications.py
+++ b/tests/test_applications.py
@@ -29,13 +29,18 @@ def __call__(self, scope):
app.add_middleware(TrustedHostMiddleware, hostname="testserver")
-@app.exception_handler(Exception)
+@app.exception_handler(500)
async def error_500(request, exc):
return JSONResponse({"detail": "Server Error"}, status_code=500)
+@app.exception_handler(405)
+async def method_not_allowed(request, exc):
+ return JSONResponse({"detail": "Custom message"}, status_code=405)
+
+
@app.exception_handler(HTTPException)
-async def handler(request, exc):
+async def http_exception(request, exc):
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
@@ -136,11 +141,11 @@ def test_400():
def test_405():
response = client.post("/func")
assert response.status_code == 405
- assert response.json() == {"detail": "Method Not Allowed"}
+ assert response.json() == {"detail": "Custom message"}
response = client.post("/class")
assert response.status_code == 405
- assert response.json() == {"detail": "Method Not Allowed"}
+ assert response.json() == {"detail": "Custom message"}
def test_500():
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 51de7ff1b..a18ef3246 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -43,23 +43,23 @@ async def __call__(self, receive, send):
client = TestClient(app)
-def test_server_error():
- with pytest.raises(RuntimeError):
- response = client.get("/runtime_error")
-
- allow_500_client = TestClient(app, raise_server_exceptions=False)
- response = allow_500_client.get("/runtime_error")
- assert response.status_code == 500
- assert response.text == "Internal Server Error"
-
-
-def test_debug_enabled():
- app = ExceptionMiddleware(router)
- app.debug = True
- allow_500_client = TestClient(app, raise_server_exceptions=False)
- response = allow_500_client.get("/runtime_error")
- assert response.status_code == 500
- assert "RuntimeError" in response.text
+# def test_server_error():
+# with pytest.raises(RuntimeError):
+# response = client.get("/runtime_error")
+#
+# allow_500_client = TestClient(app, raise_server_exceptions=False)
+# response = allow_500_client.get("/runtime_error")
+# assert response.status_code == 500
+# assert response.text == "Internal Server Error"
+
+
+# def test_debug_enabled():
+# app = ExceptionMiddleware(router)
+# app.debug = True
+# allow_500_client = TestClient(app, raise_server_exceptions=False)
+# response = allow_500_client.get("/runtime_error")
+# assert response.status_code == 500
+# assert "RuntimeError" in response.text
def test_not_acceptable():