Skip to content

Commit

Permalink
Replace task cancellation in BaseHTTPMiddleware with `http.disconne…
Browse files Browse the repository at this point in the history
…ct`+`recv_stream.close` (encode#1715)

* replace BaseMiddleware cancellation after request send with closing recv_stream + http.disconnect in receive

fixes encode#1438

* Add no cover pragma on pytest.fail in tests/middleware/test_base.py

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>

* make http_disconnect_while_sending test more robust in the face of scheduling issues

* Fix issue with running middleware context manager

Reported in encode#1678 (comment)

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
3 people authored Sep 24, 2022
1 parent 70971ea commit 040d8c8
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 3 deletions.
40 changes: 37 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
Expand All @@ -24,19 +25,52 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)

if response_sent.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand Down Expand Up @@ -71,7 +105,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
207 changes: 207 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextvars
from contextlib import AsyncExitStack

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
Expand Down Expand Up @@ -206,3 +209,207 @@ async def homepage(request):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content


@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
background_task_run.set()

async def endpoint_with_background_task(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert background_task_run.is_set()


@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
context_manager_exited.set()

class ContextManagerMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async with AsyncExitStack() as stack:
stack.push_async_callback(sleep_and_set)
await self.app(scope, receive, send)

async def simple_endpoint(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[
Middleware(BaseHTTPMiddleware, dispatch=passthrough),
Middleware(ContextManagerMiddleware),
],
routes=[Route("/", simple_endpoint)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert context_manager_exited.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
async with anyio.create_task_group() as task_group:

async def cancel_on_disconnect():
while True:
message = await receive()
if message["type"] == "http.disconnect":
task_group.cancel_scope.cancel()
break

task_group.start_soon(cancel_on_disconnect)

# A timeout is set for 0.1 second in order to ensure that
# cancel_on_disconnect is scheduled by the event loop
with anyio.move_on_after(0.1):
while True:
await send(
{
"type": "http.response.body",
"body": b"chunk ",
"more_body": True,
}
)

pytest.fail(
"http.disconnect should have been received and canceled the scope"
) # pragma: no cover

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
await send(
{
"type": "http.response.body",
"body": b"first chunk, ",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": b"second chunk",
"more_body": True,
}
)
message = await receive()
assert message["type"] == "http.disconnect"

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

0 comments on commit 040d8c8

Please sign in to comment.