Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace task cancellation in BaseHTTPMiddleware with http.disconnect+recv_stream.close #1715

Merged
merged 8 commits into from
Sep 24, 2022
40 changes: 37 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
Expand All @@ -24,19 +25,52 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)
Comment on lines +38 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What this is doing is saying "wait for a message from the client but if response_sent gets set in the meantime then stop waiting/reading from the client and move on"

Copy link
Member Author

@jhominal jhominal Sep 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the issue that I want to solve here is, if the downstream app is waiting on receive, but as the response is sent (likely by the middleware), the downstream app cannot send anything meaningfully, so there is no point in letting downstream wait for another message from upstream.

We could also choose to rely on upstream receive returning a http.disconnect message when the response is sent (which should happen), but when I wrote that bit, I thought that a belt-and-braces approach would be better.

However, that approach does mean that every call to receive from an app gets an intermediary anyio.TaskGroup for each BaseHTTPMiddleware in the middleware chain.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words, I would be open to modifying that part to remove the receive wrapper to avoid that cost if it thought to be too much.


if response_sent.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand Down Expand Up @@ -71,7 +105,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
207 changes: 207 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextvars
from contextlib import AsyncExitStack

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
Expand Down Expand Up @@ -206,3 +209,207 @@ async def homepage(request):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content


@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
background_task_run.set()

async def endpoint_with_background_task(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert background_task_run.is_set()


@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
context_manager_exited.set()

class ContextManagerMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async with AsyncExitStack() as stack:
stack.push_async_callback(sleep_and_set)
await self.app(scope, receive, send)

async def simple_endpoint(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[
Middleware(BaseHTTPMiddleware, dispatch=passthrough),
Middleware(ContextManagerMiddleware),
],
routes=[Route("/", simple_endpoint)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert context_manager_exited.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
async with anyio.create_task_group() as task_group:

async def cancel_on_disconnect():
while True:
message = await receive()
if message["type"] == "http.disconnect":
task_group.cancel_scope.cancel()
break

task_group.start_soon(cancel_on_disconnect)

# A timeout is set for 0.1 second in order to ensure that
# cancel_on_disconnect is scheduled by the event loop
with anyio.move_on_after(0.1):
while True:
await send(
{
"type": "http.response.body",
"body": b"chunk ",
"more_body": True,
}
)

pytest.fail(
"http.disconnect should have been received and canceled the scope"
) # pragma: no cover

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
await send(
{
"type": "http.response.body",
"body": b"first chunk, ",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": b"second chunk",
"more_body": True,
}
)
message = await receive()
assert message["type"] == "http.disconnect"

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"