diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index aa16752639..4341458ccc 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -1,7 +1,7 @@ -from typing import AsyncGenerator, Callable, Optional, Union +from typing import Any, AsyncGenerator, Callable, Optional, Union from .._compat import aclosing -from ..datastructures import MutableHeaders +from ..datastructures import Headers from ..requests import HTTPConnection from ..responses import Response from ..types import ASGIApp, Message, Receive, Scope, Send @@ -73,8 +73,12 @@ async def wrapped_send(message: Message) -> None: if message["type"] == "http.response.start": response_started = True - response = Response(status_code=message["status"]) - response.raw_headers.clear() + headers = Headers(raw=message["headers"]) + response = _StubResponse( + status_code=message["status"], + media_type=headers.get("content-type"), + ) + response.raw_headers = headers.raw try: await flow.asend(response) @@ -83,9 +87,7 @@ async def wrapped_send(message: Message) -> None: else: raise RuntimeError("dispatch() should yield exactly once") - headers = MutableHeaders(raw=message["headers"]) - headers.update(response.headers) - message["headers"] = headers.raw + message["headers"] = response.raw_headers await send(message) @@ -111,3 +113,58 @@ async def wrapped_send(message: Message) -> None: await response(scope, receive, send) return + + +# This customized stub response helps prevent users from shooting themselves +# in the foot, doing things that don't actually have any effect. + + +class _StubResponse(Response): + def __init__(self, status_code: int, media_type: Optional[str] = None) -> None: + self._status_code = status_code + self._media_type = media_type + self.raw_headers = [] + + @property # type: ignore + def status_code(self) -> int: # type: ignore + return self._status_code + + @status_code.setter + def status_code(self, value: Any) -> None: + raise RuntimeError( + "Setting .status_code in HTTPMiddleware is not supported. " + "If you're writing middleware that requires modifying the response " + "status code or sending another response altogether, please consider " + "writing pure ASGI middleware. " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) + + @property # type: ignore + def media_type(self) -> Optional[str]: # type: ignore + return self._media_type + + @media_type.setter + def media_type(self, value: Any) -> None: + raise RuntimeError( + "Setting .media_type in HTTPMiddleware is not supported, as it has " + "no effect. If you do need to tweak the response " + "content type, consider: response.headers['Content-Type'] = ..." + ) + + @property # type: ignore + def body(self) -> bytes: # type: ignore + raise RuntimeError( + "Accessing the response body in HTTPMiddleware is not supported. " + "If you're writing middleware that requires peeking into the response " + "body, please consider writing pure ASGI middleware and wrapping send(). " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) + + @body.setter + def body(self, body: bytes) -> None: + raise RuntimeError( + "Setting the response body in HTTPMiddleware is not supported." + "If you're writing middleware that requires modifying the response " + "body, please consider writing pure ASGI middleware and wrapping send(). " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index cbe1c3ac73..935729cafc 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -127,8 +127,8 @@ async def dispatch( def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: - async def index(request: Request) -> Response: - return PlainTextResponse("Hello, world!") + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") class CustomMiddleware(HTTPMiddleware): async def dispatch( @@ -140,14 +140,14 @@ async def dispatch( yield None app = Starlette( - routes=[Route("/", index)], + routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)], ) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 - assert response.text == "Hello, world!" + assert response.text == "OK" response = client.get("/", headers={"X-Early": "true"}) assert response.status_code == 401 @@ -202,7 +202,7 @@ def test_error_handling_must_send_response( class Failed(Exception): pass - async def index(request: Request) -> Response: + async def failure(request: Request) -> Response: raise Failed() class CustomMiddleware(HTTPMiddleware): @@ -215,19 +215,72 @@ async def dispatch( pass # `yield ` expected app = Starlette( - routes=[Route("/", index)], + routes=[Route("/fail", failure)], middleware=[Middleware(CustomMiddleware)], ) client = test_client_factory(app) with pytest.raises(RuntimeError, match="no response was returned"): - client.get("/") + client.get("/fail") def test_no_dispatch_given( test_client_factory: Callable[[ASGIApp], TestClient] ) -> None: app = Starlette(middleware=[Middleware(HTTPMiddleware)]) + client = test_client_factory(app) with pytest.raises(NotImplementedError, match="No dispatch implementation"): client.get("/") + + +def test_response_stub_attributes( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") + + async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + if conn.url.path == "/status_code": + response.status_code = 401 + if conn.url.path == "/media_type": + response.media_type = "text/csv" + if conn.url.path == "/body-get": + response.body + if conn.url.path == "/body-set": + response.body = b"changed" + + app = Starlette( + routes=[ + Route("/status_code", homepage), + Route("/media_type", homepage), + Route("/body-get", homepage), + Route("/body-set", homepage), + ], + middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)], + ) + + client = test_client_factory(app) + + with pytest.raises( + RuntimeError, match="Setting .status_code in HTTPMiddleware is not supported." + ): + client.get("/status_code") + + with pytest.raises( + RuntimeError, match="Setting .media_type in HTTPMiddleware is not supported" + ): + client.get("/media_type") + + with pytest.raises( + RuntimeError, + match="Accessing the response body in HTTPMiddleware is not supported", + ): + client.get("/body-get") + + with pytest.raises( + RuntimeError, + match="Setting the response body in HTTPMiddleware is not supported", + ): + client.get("/body-set")