Skip to content

Commit

Permalink
Allow to raise HTTPException before websocket.accept() (#2725)
Browse files Browse the repository at this point in the history
* Allow to raise `HTTPException` before `websocket.accept()`

* move <<

* Add documentation
  • Loading branch information
Kludex authored Oct 15, 2024
1 parent 4ded4b7 commit 99b6938
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 54 deletions.
26 changes: 21 additions & 5 deletions docs/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,30 @@ In order to deal with this behaviour correctly, the middleware stack of a

## HTTPException

The `HTTPException` class provides a base class that you can use for any
handled exceptions. The `ExceptionMiddleware` implementation defaults to
returning plain-text HTTP responses for any `HTTPException`.
The `HTTPException` class provides a base class that you can use for any handled exceptions.
The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`.

* `HTTPException(status_code, detail=None, headers=None)`

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.
You should only raise `HTTPException` inside routing or endpoints.
Middleware classes should instead just return appropriate responses directly.

You can use an `HTTPException` on a WebSocket endpoint in case it's raised before `websocket.accept()`.
The connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned.

```python
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket


async def websocket_endpoint(websocket: WebSocket):
raise HTTPException(status_code=400, detail="Bad request")


app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
```

## WebSocketException

Expand Down
36 changes: 8 additions & 28 deletions starlette/_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import (
ASGIApp,
ExceptionHandler,
HTTPExceptionHandler,
Message,
Receive,
Scope,
Send,
WebSocketExceptionHandler,
)
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
Expand Down Expand Up @@ -62,24 +53,13 @@ async def sender(message: Message) -> None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if scope["type"] == "http":
nonlocal conn
handler = typing.cast(HTTPExceptionHandler, handler)
conn = typing.cast(Request, conn)
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
raise RuntimeError("Caught handled exception, but response already started.") from exc

if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc) # type: ignore
if response is not None:
await response(scope, receive, sender)
elif scope["type"] == "websocket":
handler = typing.cast(WebSocketExceptionHandler, handler)
conn = typing.cast(WebSocket, conn)
if is_async_callable(handler):
await handler(conn, exc)
else:
await run_in_threadpool(handler, conn, exc)

return wrapped_app
6 changes: 1 addition & 5 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,7 @@ def _raise_on_close(self, message: Message) -> None:
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(
status_code=status_code,
headers=headers,
content=b"".join(body),
)
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
24 changes: 19 additions & 5 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Generator

import anyio
import anyio.from_thread
import pytest

from starlette import status
Expand All @@ -17,7 +17,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
from starlette.testclient import TestClient, WebSocketDenialResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
Expand Down Expand Up @@ -71,11 +71,15 @@ async def websocket_endpoint(session: WebSocket) -> None:
await session.close()


async def websocket_raise_websocket(websocket: WebSocket) -> None:
async def websocket_raise_websocket_exception(websocket: WebSocket) -> None:
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


async def websocket_raise_http_exception(websocket: WebSocket) -> None:
raise HTTPException(status_code=401, detail="Unauthorized")


class CustomWSException(Exception):
pass

Expand Down Expand Up @@ -118,7 +122,8 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
Expand Down Expand Up @@ -219,6 +224,14 @@ def test_websocket_raise_websocket_exception(client: TestClient) -> None:
}


def test_websocket_raise_http_exception(client: TestClient) -> None:
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/ws-raise-http"):
pass # pragma: no cover
assert exc.value.status_code == 401
assert exc.value.content == b'{"detail":"Unauthorized"}'


def test_websocket_raise_custom_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
Expand All @@ -243,7 +256,8 @@ def test_routes() -> None:
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount(
"/users",
Expand Down
15 changes: 4 additions & 11 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
]
)

Expand Down Expand Up @@ -114,13 +110,10 @@ def test_websockets_should_raise(client: TestClient) -> None:
pass # pragma: no cover


def test_handled_exc_after_response(
test_client_factory: TestClientFactory,
client: TestClient,
) -> None:
def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
# A 406 HttpException is raised *after* the response has already been sent.
# The exception middleware should raise a RuntimeError.
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."):
client.get("/handled_exc_after_response")

# If `raise_server_exceptions=False` then the test client will still allow
Expand All @@ -132,7 +125,7 @@ def test_handled_exc_after_response(


def test_force_500_response(test_client_factory: TestClientFactory) -> None:
# use a sentinal variable to make sure we actually
# use a sentinel variable to make sure we actually
# make it into the endpoint and don't get a 500
# from an incorrect ASGI app signature or something
called = False
Expand Down

0 comments on commit 99b6938

Please sign in to comment.