Skip to content

Commit

Permalink
Attempt controlling response attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jun 15, 2022
1 parent 9e48c1f commit a18bb82
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 14 deletions.
71 changes: 64 additions & 7 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import AsyncGenerator, Callable, Optional, Union
from typing import Any, AsyncGenerator, Callable, Optional, Union

from .._compat import aclosing
from ..datastructures import MutableHeaders
from ..datastructures import Headers
from ..requests import HTTPConnection
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -73,8 +73,12 @@ async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
response_started = True

response = Response(status_code=message["status"])
response.raw_headers.clear()
headers = Headers(raw=message["headers"])
response = _StubResponse(
status_code=message["status"],
media_type=headers.get("content-type"),
)
response.raw_headers = headers.raw

try:
await flow.asend(response)
Expand All @@ -83,9 +87,7 @@ async def wrapped_send(message: Message) -> None:
else:
raise RuntimeError("dispatch() should yield exactly once")

headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)
message["headers"] = headers.raw
message["headers"] = response.raw_headers

await send(message)

Expand All @@ -111,3 +113,58 @@ async def wrapped_send(message: Message) -> None:

await response(scope, receive, send)
return


# This customized stub response helps prevent users from shooting themselves
# in the foot, doing things that don't actually have any effect.


class _StubResponse(Response):
def __init__(self, status_code: int, media_type: Optional[str] = None) -> None:
self._status_code = status_code
self._media_type = media_type
self.raw_headers = []

@property # type: ignore
def status_code(self) -> int: # type: ignore
return self._status_code

@status_code.setter
def status_code(self, value: Any) -> None:
raise RuntimeError(
"Setting .status_code in HTTPMiddleware is not supported. "
"If you're writing middleware that requires modifying the response "
"status code or sending another response altogether, please consider "
"writing pure ASGI middleware. "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)

@property # type: ignore
def media_type(self) -> Optional[str]: # type: ignore
return self._media_type

@media_type.setter
def media_type(self, value: Any) -> None:
raise RuntimeError(
"Setting .media_type in HTTPMiddleware is not supported, as it has "
"no effect. If you do need to tweak the response "
"content type, consider: response.headers['Content-Type'] = ..."
)

@property # type: ignore
def body(self) -> bytes: # type: ignore
raise RuntimeError(
"Accessing the response body in HTTPMiddleware is not supported. "
"If you're writing middleware that requires peeking into the response "
"body, please consider writing pure ASGI middleware and wrapping send(). "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)

@body.setter
def body(self, body: bytes) -> None:
raise RuntimeError(
"Setting the response body in HTTPMiddleware is not supported."
"If you're writing middleware that requires modifying the response "
"body, please consider writing pure ASGI middleware and wrapping send(). "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)
67 changes: 60 additions & 7 deletions tests/middleware/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ async def dispatch(


def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
async def index(request: Request) -> Response:
return PlainTextResponse("Hello, world!")
async def homepage(request: Request) -> Response:
return PlainTextResponse("OK")

class CustomMiddleware(HTTPMiddleware):
async def dispatch(
Expand All @@ -140,14 +140,14 @@ async def dispatch(
yield None

app = Starlette(
routes=[Route("/", index)],
routes=[Route("/", homepage)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.text == "OK"
response = client.get("/", headers={"X-Early": "true"})
assert response.status_code == 401

Expand Down Expand Up @@ -202,7 +202,7 @@ def test_error_handling_must_send_response(
class Failed(Exception):
pass

async def index(request: Request) -> Response:
async def failure(request: Request) -> Response:
raise Failed()

class CustomMiddleware(HTTPMiddleware):
Expand All @@ -215,19 +215,72 @@ async def dispatch(
pass # `yield <response>` expected

app = Starlette(
routes=[Route("/", index)],
routes=[Route("/fail", failure)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
with pytest.raises(RuntimeError, match="no response was returned"):
client.get("/")
client.get("/fail")


def test_no_dispatch_given(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
app = Starlette(middleware=[Middleware(HTTPMiddleware)])

client = test_client_factory(app)
with pytest.raises(NotImplementedError, match="No dispatch implementation"):
client.get("/")


def test_response_stub_attributes(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def homepage(request: Request) -> Response:
return PlainTextResponse("OK")

async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
response = yield
if conn.url.path == "/status_code":
response.status_code = 401
if conn.url.path == "/media_type":
response.media_type = "text/csv"
if conn.url.path == "/body-get":
response.body
if conn.url.path == "/body-set":
response.body = b"changed"

app = Starlette(
routes=[
Route("/status_code", homepage),
Route("/media_type", homepage),
Route("/body-get", homepage),
Route("/body-set", homepage),
],
middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
)

client = test_client_factory(app)

with pytest.raises(
RuntimeError, match="Setting .status_code in HTTPMiddleware is not supported."
):
client.get("/status_code")

with pytest.raises(
RuntimeError, match="Setting .media_type in HTTPMiddleware is not supported"
):
client.get("/media_type")

with pytest.raises(
RuntimeError,
match="Accessing the response body in HTTPMiddleware is not supported",
):
client.get("/body-get")

with pytest.raises(
RuntimeError,
match="Setting the response body in HTTPMiddleware is not supported",
):
client.get("/body-set")

0 comments on commit a18bb82

Please sign in to comment.