Skip to content

Commit

Permalink
Add WebSocketException and support for WS handlers (#1263)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
  • Loading branch information
4 people authored Sep 5, 2022
1 parent 9386bcf commit d525431
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 13 deletions.
19 changes: 19 additions & 0 deletions docs/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ async def http_exception(request: Request, exc: HTTPException):
)
```

You might also want to override how `WebSocketException` is handled:

```python
async def websocket_exception(websocket: WebSocket, exc: WebSocketException):
await websocket.close(code=1008)

exception_handlers = {
WebSocketException: websocket_exception
}
```

## Errors and handled exceptions

It is important to differentiate between handled exceptions and errors.
Expand Down Expand Up @@ -112,3 +123,11 @@ returning plain-text HTTP responses for any `HTTPException`.

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.

## WebSocketException

You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.

* `WebSocketException(code=1008, reason=None)`

You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
12 changes: 11 additions & 1 deletion starlette/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings

__all__ = ("HTTPException",)
__all__ = ("HTTPException", "WebSocketException")


class HTTPException(Exception):
Expand All @@ -23,6 +23,16 @@ def __repr__(self) -> str:
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class WebSocketException(Exception):
def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
self.code = code
self.reason = reason or ""

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"


__deprecated__ = "ExceptionMiddleware"


Expand Down
34 changes: 25 additions & 9 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket


class ExceptionMiddleware:
Expand All @@ -22,7 +23,10 @@ def __init__(
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {HTTPException: self.http_exception}
] = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)
Expand All @@ -47,7 +51,7 @@ def _lookup_exception_handler(
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

Expand Down Expand Up @@ -78,16 +82,28 @@ async def sender(message: Message) -> None:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
48 changes: 47 additions & 1 deletion tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import os
from contextlib import asynccontextmanager

import anyio
import pytest

from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.websockets import WebSocket


async def error_500(request, exc):
Expand Down Expand Up @@ -61,6 +64,24 @@ async def websocket_endpoint(session):
await session.close()


async def websocket_raise_websocket(websocket: WebSocket):
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


class CustomWSException(Exception):
pass


async def websocket_raise_custom(websocket: WebSocket):
await websocket.accept()
raise CustomWSException()


def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


users = Router(
routes=[
Route("/", endpoint=all_users_page),
Expand All @@ -78,6 +99,7 @@ async def websocket_endpoint(session):
500: error_500,
405: method_not_allowed,
HTTPException: http_exception,
CustomWSException: custom_ws_exception_handler,
}

middleware = [
Expand All @@ -91,6 +113,8 @@ async def websocket_endpoint(session):
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
],
Expand Down Expand Up @@ -180,6 +204,26 @@ def test_500(test_client_factory):
assert response.json() == {"detail": "Server Error"}


def test_websocket_raise_websocket_exception(client):
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1003_UNSUPPORTED_DATA,
"reason": "",
}


def test_websocket_raise_custom_exception(client):
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1013_TRY_AGAIN_LATER,
"reason": "",
}


def test_middleware(test_client_factory):
client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
Expand All @@ -194,6 +238,8 @@ def test_routes():
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount(
"/users",
app=Router(
Expand Down
18 changes: 16 additions & 2 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
Expand Down Expand Up @@ -119,7 +119,7 @@ async def app(scope, receive, send):
assert response.text == ""


def test_repr():
def test_http_repr():
assert repr(HTTPException(404)) == (
"HTTPException(status_code=404, detail='Not Found')"
)
Expand All @@ -135,6 +135,20 @@ class CustomHTTPException(HTTPException):
)


def test_websocket_repr():
assert repr(WebSocketException(1008, reason="Policy Violation")) == (
"WebSocketException(code=1008, reason='Policy Violation')"
)

class CustomWebSocketException(WebSocketException):
pass

assert (
repr(CustomWebSocketException(1013, reason="Something custom"))
== "CustomWebSocketException(code=1013, reason='Something custom')"
)


def test_exception_middleware_deprecation() -> None:
# this test should be removed once the deprecation shim is removed
with pytest.warns(DeprecationWarning):
Expand Down

0 comments on commit d525431

Please sign in to comment.