From d1b33779d8892216f08e1ae74739c828bd9e0373 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Wed, 29 Jun 2022 23:59:41 +0200 Subject: [PATCH 1/6] replace BaseMiddleware cancellation after request send with closing recv_stream + http.disconnect in receive fixes #1438 --- starlette/middleware/base.py | 40 +++++++++- tests/middleware/test_base.py | 138 ++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..586c9870d 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,12 +4,13 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] ] +T = typing.TypeVar("T") class BaseHTTPMiddleware: @@ -24,19 +25,52 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + response_sent = anyio.Event() + async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() + async def receive_or_disconnect() -> Message: + if response_sent.is_set(): + return {"type": "http.disconnect"} + + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: + result = await func() + task_group.cancel_scope.cancel() + return result + + task_group.start_soon(wrap, response_sent.wait) + message = await wrap(request.receive) + + if response_sent.is_set(): + return {"type": "http.disconnect"} + + return message + + async def close_recv_stream_on_response_sent() -> None: + await response_sent.wait() + recv_stream.close() + + async def send_no_error(message: Message) -> None: + try: + await send_stream.send(message) + except anyio.BrokenResourceError: + # recv_stream has been closed, i.e. response_sent has been set. + return + async def coro() -> None: nonlocal app_exc async with send_stream: try: - await self.app(scope, request.receive, send_stream.send) + await self.app(scope, receive_or_disconnect, send_no_error) except Exception as exc: app_exc = exc + task_group.start_soon(close_recv_stream_on_response_sent) task_group.start_soon(coro) try: @@ -71,7 +105,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) - task_group.cancel_scope.cancel() + response_sent.set() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 976d77b86..5dac98d39 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,8 +1,10 @@ import contextvars +import anyio import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse @@ -206,3 +208,139 @@ async def homepage(request): client = test_client_factory(app) response = client.get("/") assert response.status_code == 200, response.content + + +@pytest.mark.anyio +async def test_run_background_tasks_even_if_client_disconnects(): + # test for https://github.com/encode/starlette/issues/1438 + request_body_sent = False + response_complete = anyio.Event() + background_task_run = anyio.Event() + + async def sleep_and_set(): + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + background_task_run.set() + + async def endpoint_with_background_task(_): + return PlainTextResponse(background=BackgroundTask(sleep_and_set)) + + async def passthrough(request, call_next): + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_background_task)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive(): + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + # We simulate a client that disconnects immediately after receiving the response + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message): + if message["type"] == "http.response.body": + if not message.get("more_body", False): + response_complete.set() + + await app(scope, receive, send) + + assert background_task_run.is_set() + + +def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory): + class DiscardingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + async def downstream_app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/plain"), + ], + } + ) + async with anyio.create_task_group() as task_group: + + async def cancel_on_disconnect(): + while True: + message = await receive() + if message["type"] == "http.disconnect": + task_group.cancel_scope.cancel() + break + + task_group.start_soon(cancel_on_disconnect) + + await send( + { + "type": "http.response.body", + "body": b"chunk", + "more_body": True, + } + ) + pytest.fail( + "http.disconnect should have been received and canceled the scope" + ) + + app = DiscardingMiddleware(downstream_app) + + client = test_client_factory(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" + + +def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory): + class DiscardingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + async def downstream_app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/plain"), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": b"first chunk, ", + "more_body": True, + } + ) + await send( + { + "type": "http.response.body", + "body": b"second chunk", + "more_body": True, + } + ) + message = await receive() + assert message["type"] == "http.disconnect" + + app = DiscardingMiddleware(downstream_app) + + client = test_client_factory(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" From 7751ddc2209b601d0c14432e4bd38675df2a6b28 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Sat, 2 Jul 2022 07:26:52 +0200 Subject: [PATCH 2/6] Add no cover pragma on pytest.fail in tests/middleware/test_base.py Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- tests/middleware/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 5dac98d39..7519fdf1d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -297,7 +297,7 @@ async def cancel_on_disconnect(): ) pytest.fail( "http.disconnect should have been received and canceled the scope" - ) + ) # pragma: no cover app = DiscardingMiddleware(downstream_app) From 137776bbba9b7f43957dd7d15de442fa549714c1 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Sat, 2 Jul 2022 08:07:41 +0200 Subject: [PATCH 3/6] make http_disconnect_while_sending test more robust in the face of scheduling issues --- tests/middleware/test_base.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7519fdf1d..8c110ca5d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -288,13 +288,18 @@ async def cancel_on_disconnect(): task_group.start_soon(cancel_on_disconnect) - await send( - { - "type": "http.response.body", - "body": b"chunk", - "more_body": True, - } - ) + # A timeout is set for 0.1 second in order to ensure that + # cancel_on_disconnect is scheduled by the event loop + with anyio.move_on_after(0.1): + while True: + await send( + { + "type": "http.response.body", + "body": b"chunk ", + "more_body": True, + } + ) + pytest.fail( "http.disconnect should have been received and canceled the scope" ) # pragma: no cover From 10628438f013f892cc49101b7a2fade1a7993948 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Wed, 6 Jul 2022 22:26:00 +0200 Subject: [PATCH 4/6] Fix issue with running middleware context manager Reported in https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 --- tests/middleware/test_base.py | 64 +++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8c110ca5d..ed0734bd3 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,4 +1,5 @@ import contextvars +from contextlib import AsyncExitStack import anyio import pytest @@ -261,6 +262,69 @@ async def send(message): assert background_task_run.is_set() +@pytest.mark.anyio +async def test_run_context_manager_exit_even_if_client_disconnects(): + # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 + request_body_sent = False + response_complete = anyio.Event() + context_manager_exited = anyio.Event() + + async def sleep_and_set(): + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + context_manager_exited.set() + + class ContextManagerMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + async with AsyncExitStack() as stack: + stack.push_async_callback(sleep_and_set) + await self.app(scope, receive, send) + + async def simple_endpoint(_): + return PlainTextResponse(background=BackgroundTask(sleep_and_set)) + + async def passthrough(request, call_next): + return await call_next(request) + + app = Starlette( + middleware=[ + Middleware(BaseHTTPMiddleware, dispatch=passthrough), + Middleware(ContextManagerMiddleware), + ], + routes=[Route("/", simple_endpoint)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive(): + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + # We simulate a client that disconnects immediately after receiving the response + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message): + if message["type"] == "http.response.body": + if not message.get("more_body", False): + response_complete.set() + + await app(scope, receive, send) + + assert context_manager_exited.is_set() + + def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory): class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): From b98d6a1b547f3e1ed171c1f30181ee74014db2fd Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Sun, 3 Jul 2022 10:57:11 +0200 Subject: [PATCH 5/6] reorganize BaseHTTPMiddleware tasks so that app is in the same task as __call__ This has the following effects: * contextvars set in dispatch are not visible in endpoints * contextvars set in endpoints are visible upstream --- starlette/middleware/base.py | 102 +++++++++++++++++++--------------- tests/middleware/test_base.py | 58 ++++++++++++++++++- 2 files changed, 112 insertions(+), 48 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 586c9870d..4608fdd1c 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -25,11 +25,58 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + dispatch_first_phase_ended = anyio.Event() + streams_ready = anyio.Event() + request_for_next: typing.Optional[Request] = None response_sent = anyio.Event() + app_exc: typing.Optional[Exception] = None async def call_next(request: Request) -> Response: - app_exc: typing.Optional[Exception] = None - send_stream, recv_stream = anyio.create_memory_object_stream() + nonlocal request_for_next + request_for_next = request + dispatch_first_phase_ended.set() + await streams_ready.wait() + + try: + message = await recv_stream.receive() + except anyio.EndOfStream: + if app_exc is not None: + raise app_exc + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break + + if app_exc is not None: + raise app_exc + + response = StreamingResponse( + status_code=message["status"], content=body_stream() + ) + response.raw_headers = message["headers"] + return response + + async def process_dispatch(request: Request): + response = await self.dispatch_func(request, call_next) + await response(scope, receive, send) + dispatch_first_phase_ended.set() + response_sent.set() + + async with anyio.create_task_group() as task_group: + task_group.start_soon(process_dispatch, Request(scope, receive=receive)) + + await dispatch_first_phase_ended.wait() + if request_for_next is None: + return async def receive_or_disconnect() -> Message: if response_sent.is_set(): @@ -43,7 +90,7 @@ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: return result task_group.start_soon(wrap, response_sent.wait) - message = await wrap(request.receive) + message = await wrap(request_for_next.receive) if response_sent.is_set(): return {"type": "http.disconnect"} @@ -61,51 +108,16 @@ async def send_no_error(message: Message) -> None: # recv_stream has been closed, i.e. response_sent has been set. return - async def coro() -> None: - nonlocal app_exc - - async with send_stream: - try: - await self.app(scope, receive_or_disconnect, send_no_error) - except Exception as exc: - app_exc = exc + send_stream, recv_stream = anyio.create_memory_object_stream() + streams_ready.set() task_group.start_soon(close_recv_stream_on_response_sent) - task_group.start_soon(coro) - - try: - message = await recv_stream.receive() - except anyio.EndOfStream: - if app_exc is not None: - raise app_exc - raise RuntimeError("No response returned.") - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - async with recv_stream: - async for message in recv_stream: - assert message["type"] == "http.response.body" - body = message.get("body", b"") - if body: - yield body - if not message.get("more_body", False): - break - - if app_exc is not None: - raise app_exc - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response - - async with anyio.create_task_group() as task_group: - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, call_next) - await response(scope, receive, send) - response_sent.set() + async with send_stream: + try: + await self.app(scope, receive_or_disconnect, send_no_error) + except Exception as exc: + app_exc = exc async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index ed0734bd3..3505fd9d6 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -7,8 +7,9 @@ from starlette.applications import Starlette from starlette.background import BackgroundTask from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import PlainTextResponse, StreamingResponse, Response from starlette.routing import Route, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send @@ -198,7 +199,7 @@ def test_contextvars(test_client_factory, middleware_cls: type): # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) async def homepage(request): - assert ctxvar.get() == "set by middleware" + assert ctxvar.get("unset") == "set by middleware" ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") @@ -211,6 +212,57 @@ async def homepage(request): assert response.status_code == 200, response.content +class TransparentASGIMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return await self.app(scope, receive, send) + + +class TransparentBaseHTTPMiddleware(BaseHTTPMiddleware): + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + +@pytest.mark.parametrize( + "middleware_cls", + [ + TransparentASGIMiddleware, + TransparentBaseHTTPMiddleware, + ], +) +@pytest.mark.anyio +async def test_endpoint_contextvars_available_upstream(middleware_cls: type): + # this has to be an async endpoint because Starlette calls run_in_threadpool + # on sync endpoints which has it's own set of peculiarities w.r.t propagating + # contextvars (it propagates them forwards but not backwards) + async def homepage(request): + ctxvar.set("set by endpoint") + return PlainTextResponse("Homepage") + + app = Starlette( + middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] + ) + + request_body_sent = False + scope = {"type": "http", "method": "GET", "path": "/"} + + async def receive(): + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request"} + await anyio.sleep_forever() + + async def send(message): + pass + + await app(scope, receive, send) + assert ctxvar.get("unset") == "set by endpoint" + + @pytest.mark.anyio async def test_run_background_tasks_even_if_client_disconnects(): # test for https://github.com/encode/starlette/issues/1438 From 0532d37afc72991245420c0412886766d0b9705b Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Thu, 7 Jul 2022 23:23:50 +0200 Subject: [PATCH 6/6] Copy contextvars set in BaseHTTPMiddleware in the __call__ context That allows context var modifications from dispatch to be visible upstream and downstream. --- starlette/middleware/base.py | 16 +++++++++++++++- tests/middleware/test_base.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 4608fdd1c..54d272cce 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,3 +1,4 @@ +import contextvars import typing import anyio @@ -30,10 +31,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request_for_next: typing.Optional[Request] = None response_sent = anyio.Event() app_exc: typing.Optional[Exception] = None + dispatch_context_copy: typing.Optional[contextvars.Context] = None async def call_next(request: Request) -> Response: - nonlocal request_for_next + nonlocal request_for_next, dispatch_context_copy request_for_next = request + dispatch_context_copy = contextvars.copy_context() dispatch_first_phase_ended.set() await streams_ready.wait() @@ -66,8 +69,10 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: return response async def process_dispatch(request: Request): + nonlocal dispatch_context_copy response = await self.dispatch_func(request, call_next) await response(scope, receive, send) + dispatch_context_copy = contextvars.copy_context() dispatch_first_phase_ended.set() response_sent.set() @@ -75,6 +80,15 @@ async def process_dispatch(request: Request): task_group.start_soon(process_dispatch, Request(scope, receive=receive)) await dispatch_first_phase_ended.wait() + + # Copy contextvars updated from dispatch into the current context. + for context_var, dispatch_context_value in dispatch_context_copy.items(): + try: + if context_var.get() is not dispatch_context_value: + context_var.set(dispatch_context_value) + except LookupError: + context_var.set(dispatch_context_value) + if request_for_next is None: return diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 3505fd9d6..4244e1fa9 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -199,7 +199,7 @@ def test_contextvars(test_client_factory, middleware_cls: type): # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) async def homepage(request): - assert ctxvar.get("unset") == "set by middleware" + assert ctxvar.get() == "set by middleware" ctxvar.set("set by endpoint") return PlainTextResponse("Homepage")