diff --git a/starlette/applications.py b/starlette/applications.py index 913fd4c9d..076b4d25f 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -11,6 +11,7 @@ from starlette.datastructures import State, URLPath from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware.background import BackgroundTaskMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -96,6 +97,7 @@ def build_middleware_stack(self) -> ASGIApp: middleware = ( [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + [Middleware(BackgroundTaskMiddleware)] + self.user_middleware + [ Middleware( diff --git a/starlette/middleware/background.py b/starlette/middleware/background.py new file mode 100644 index 000000000..13e18f049 --- /dev/null +++ b/starlette/middleware/background.py @@ -0,0 +1,37 @@ +from typing import List, cast + +from starlette.background import BackgroundTask +from starlette.types import ASGIApp, Receive, Scope, Send + +# consider this a private implementation detail subject to change +# do not rely on this key +_SCOPE_KEY = "starlette._background" + + +_BackgroundTaskList = List[BackgroundTask] + + +def is_background_task_middleware_installed(scope: Scope) -> bool: + return _SCOPE_KEY in scope + + +def add_tasks(scope: Scope, task: BackgroundTask, /) -> None: + if _SCOPE_KEY not in scope: # pragma: no cover + raise RuntimeError( + "`add_tasks` can only be used if `BackgroundTaskMIddleware is installed" + ) + cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(task) + + +class BackgroundTaskMiddleware: + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + tasks: _BackgroundTaskList + scope[_SCOPE_KEY] = tasks = [] + try: + await self._app(scope, receive, send) + finally: + for task in tasks: + await task() diff --git a/starlette/responses.py b/starlette/responses.py index a6975747b..911c604e5 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -19,6 +19,7 @@ from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders +from starlette.middleware import background from starlette.types import Receive, Scope, Send @@ -148,6 +149,12 @@ def delete_cookie( ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None prefix = "websocket." if scope["type"] == "websocket" else "" await send( { @@ -255,6 +262,12 @@ async def stream_response(self, send: Send) -> None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None async with anyio.create_task_group() as task_group: async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: @@ -322,6 +335,12 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: self.headers.setdefault("etag", etag) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if ( + self.background is not None + and background.is_background_task_middleware_installed(scope) + ): + background.add_tasks(scope, self.background) + self.background = None if self.stat_result is None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.path) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..3860c1fcb 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,8 +5,8 @@ from typing import ( Any, AsyncGenerator, - Callable, Generator, + Literal, ) import anyio @@ -14,8 +14,9 @@ from anyio.abc import TaskStatus from starlette.applications import Starlette -from starlette.background import BackgroundTask +from starlette.background import BackgroundTask, BackgroundTasks from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware.background import BackgroundTaskMiddleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -23,8 +24,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.conftest import TestClientFactory class CustomMiddleware(BaseHTTPMiddleware): @@ -372,8 +372,8 @@ async def send(message: Message) -> None: {"body": b"Hello", "more_body": True, "type": "http.response.body"}, {"body": b"", "more_body": False, "type": "http.response.body"}, "Background task started", - "Background task started", "Background task finished", + "Background task started", "Background task finished", ] @@ -1035,3 +1035,97 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + + +@pytest.mark.anyio +async def test_background_tasks_client_disconnect() -> None: + # test for https://github.com/encode/starlette/issues/1438 + container: list[str] = [] + + disconnected = anyio.Event() + + async def slow_background() -> None: + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + container.append("called") + + app: ASGIApp + app = PlainTextResponse("hi!", background=BackgroundTask(slow_background)) + + async def dispatch( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = BaseHTTPMiddleware(app, dispatch=dispatch) + + app = BackgroundTaskMiddleware(app) + + async def recv_gen() -> AsyncGenerator[Message, None]: + yield {"type": "http.request"} + await disconnected.wait() + while True: + yield {"type": "http.disconnect"} + + async def send_gen() -> AsyncGenerator[None, Message]: + while True: + msg = yield + if msg["type"] == "http.response.body" and not msg.get("more_body", False): + disconnected.set() + + scope = {"type": "http", "method": "GET", "path": "/"} + + async with AsyncExitStack() as stack: + recv = recv_gen() + stack.push_async_callback(recv.aclose) + send = send_gen() + stack.push_async_callback(send.aclose) + await send.__anext__() + await app(scope, recv.__aiter__().__anext__, send.asend) + + assert container == ["called"] + + +@pytest.mark.anyio +async def test_background_tasks_failure( + test_client_factory: TestClientFactory, + anyio_backend_name: Literal["asyncio", "trio"], +) -> None: + if anyio_backend_name == "trio": + pytest.skip("this test hangs with trio") + + # test for https://github.com/encode/starlette/discussions/2640 + container: list[str] = [] + + async def task1() -> None: + container.append("task1 called") + raise ValueError("task1 failed") + + async def task2() -> None: + container.append("task2 called") # pragma: no cover + + async def endpoint(request: Request) -> Response: + background = BackgroundTasks() + background.add_task(task1) + background.add_task(task2) + return PlainTextResponse("hi!", background=background) + + async def dispatch( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = Starlette( + routes=[Route("/", endpoint)], + middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)], + ) + + client = test_client_factory(app, raise_server_exceptions=False) + + response = client.get("/") + assert response.status_code == 200 + assert response.text == "hi!" + + assert container == ["task1 called"] diff --git a/tests/test_background.py b/tests/test_background.py index 846deecfd..2ebe75a8a 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,40 +1,100 @@ -from typing import Callable +from __future__ import annotations + +from tempfile import NamedTemporaryFile +from typing import Any, AsyncIterable, Callable import pytest from starlette.background import BackgroundTask, BackgroundTasks -from starlette.responses import Response +from starlette.middleware.background import BackgroundTaskMiddleware +from starlette.responses import FileResponse, Response, StreamingResponse from starlette.testclient import TestClient -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send -TestClientFactory = Callable[..., TestClient] +TestClientFactory = Callable[[ASGIApp], TestClient] -def test_async_task(test_client_factory: TestClientFactory) -> None: - TASK_COMPLETE = False +@pytest.fixture( + params=[[], [BackgroundTaskMiddleware]], + ids=["without BackgroundTaskMiddleware", "with BackgroundTaskMiddleware"], +) +def test_client_factory_mw( + test_client_factory: TestClientFactory, request: Any +) -> TestClientFactory: + mw_stack: list[Callable[[ASGIApp], ASGIApp]] = request.param - async def async_task() -> None: - nonlocal TASK_COMPLETE - TASK_COMPLETE = True + def client_factory(app: ASGIApp) -> TestClient: + for mw in mw_stack: + app = mw(app) + return test_client_factory(app) + + return client_factory - task = BackgroundTask(async_task) +def response_app_factory(task: BackgroundTask) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: - response = Response("task initiated", media_type="text/plain", background=task) + response = Response(b"task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = test_client_factory(app) + return app + + +def file_response_app_factory(task: BackgroundTask) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + with NamedTemporaryFile("wb+") as f: + f.write(b"task initiated") + f.seek(0) + response = FileResponse(f.name, media_type="text/plain", background=task) + await response(scope, receive, send) + + return app + + +def streaming_response_app_factory(task: BackgroundTask) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + async def stream() -> AsyncIterable[bytes]: + yield b"task initiated" + + response = StreamingResponse(stream(), media_type="text/plain", background=task) + await response(scope, receive, send) + + return app + + +@pytest.mark.parametrize( + "app_factory", + [ + response_app_factory, + streaming_response_app_factory, + file_response_app_factory, + ], +) +def test_async_task( + test_client_factory_mw: TestClientFactory, + app_factory: Callable[[BackgroundTask], ASGIApp], +) -> None: + task_complete = False + + async def async_task() -> None: + nonlocal task_complete + task_complete = True + + task = BackgroundTask(async_task) + + app = app_factory(task) + + client = test_client_factory_mw(app) response = client.get("/") assert response.text == "task initiated" - assert TASK_COMPLETE + assert task_complete def test_sync_task(test_client_factory: TestClientFactory) -> None: - TASK_COMPLETE = False + task_complete = False def sync_task() -> None: - nonlocal TASK_COMPLETE - TASK_COMPLETE = True + nonlocal task_complete + task_complete = True task = BackgroundTask(sync_task) @@ -45,15 +105,15 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" - assert TASK_COMPLETE + assert task_complete def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: - TASK_COUNTER = 0 + task_counter = 0 def increment(amount: int) -> None: - nonlocal TASK_COUNTER - TASK_COUNTER += amount + nonlocal task_counter + task_counter += amount async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() @@ -68,18 +128,18 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" - assert TASK_COUNTER == 1 + 2 + 3 + assert task_counter == 1 + 2 + 3 def test_multi_tasks_failure_avoids_next_execution( test_client_factory: TestClientFactory, ) -> None: - TASK_COUNTER = 0 + task_counter = 0 def increment() -> None: - nonlocal TASK_COUNTER - TASK_COUNTER += 1 - if TASK_COUNTER == 1: + nonlocal task_counter + task_counter += 1 + if task_counter == 1: raise Exception("task failed") async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -94,4 +154,4 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) with pytest.raises(Exception): client.get("/") - assert TASK_COUNTER == 1 + assert task_counter == 1 diff --git a/tests/test_responses.py b/tests/test_responses.py index 434cc5a22..a1df737b8 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -13,6 +13,7 @@ from starlette import status from starlette.background import BackgroundTask from starlette.datastructures import Headers +from starlette.middleware.background import BackgroundTaskMiddleware from starlette.requests import Request from starlette.responses import ( FileResponse, @@ -126,7 +127,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" @@ -152,7 +153,7 @@ async def __anext__(self) -> str: response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") assert response.text == "12345" @@ -245,7 +246,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) assert filled_by_bg_task == "" - client = test_client_factory(app) + client = test_client_factory(BackgroundTaskMiddleware(app)) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK