Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BackgroundTask execution outside of middleware stack #1700

Closed
wants to merge 19 commits into from
Closed
2 changes: 2 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -90,6 +91,7 @@ def build_middleware_stack(self) -> ASGIApp:

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
+ self.user_middleware
+ [
Middleware(
Expand Down
18 changes: 18 additions & 0 deletions starlette/middleware/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List

from starlette.background import BackgroundTask
from starlette.types import ASGIApp, Receive, Scope, Send


class BackgroundTaskMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
tasks: "List[BackgroundTask]"
scope["starlette.background"] = tasks = []
try:
await self._app(scope, receive, send)
finally:
for task in tasks:
await task()
15 changes: 15 additions & 0 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None and "starlette.background" in scope:
tasks: "typing.List[BackgroundTask]" = scope["starlette.background"]
tasks.append(self.background)
self.background = None

await send(
{
"type": "http.response.start",
Expand Down Expand Up @@ -263,6 +268,11 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None and "starlette.background" in scope:
tasks: "typing.List[BackgroundTask]" = scope["starlette.background"]
tasks.append(self.background)
self.background = None

async with anyio.create_task_group() as task_group:

async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None:
Expand Down Expand Up @@ -326,6 +336,11 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
self.headers.setdefault("etag", etag)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None and "starlette.background" in scope:
tasks: "typing.List[BackgroundTask]" = scope["starlette.background"]
tasks.append(self.background)
self.background = None

if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
Expand Down
89 changes: 87 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Awaitable, Callable, List

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -206,3 +213,81 @@ async def homepage(request):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content


@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: List[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
await disconnected.wait()
container.append("called")
adriangb marked this conversation as resolved.
Show resolved Hide resolved

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
disconnected.set()
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
msg = yield
assert msg["type"] == "http.response.start"
await disconnected.wait()
raise AssertionError("Should not be called") # pragma: no cover
adriangb marked this conversation as resolved.
Show resolved Hide resolved

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]


def test_background_tasks(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
# test for https://github.com/encode/starlette/issues/919
container: List[str] = []

async def slow_task() -> None:
container.append("started")
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("finished")

async def dispatch(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await call_next(request)

async def endpoint(request: Request) -> Response:
return Response(background=BackgroundTask(slow_task))

app = Starlette(
routes=[Route("/", endpoint)],
middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)],
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content
assert container == ["started", "finished"]
81 changes: 71 additions & 10 deletions tests/test_background.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,76 @@
from typing import Callable
from tempfile import NamedTemporaryFile
from typing import Any, AsyncIterable, Callable, List

import pytest

from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import Response
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send

TestClientFactory = Callable[[ASGIApp], TestClient]

def test_async_task(test_client_factory):

@pytest.fixture(
params=[[], [BackgroundTaskMiddleware]],
ids=["without BackgroundTaskMiddleware", "with BackgroundTaskMiddleware"],
)
def test_client_factory_mw(
test_client_factory: TestClientFactory, request: Any
) -> TestClientFactory:
mw_stack: List[Callable[[ASGIApp], ASGIApp]] = request.param

def client_factory(app: ASGIApp) -> TestClient:
for mw in mw_stack:
app = mw(app)
return test_client_factory(app)

return client_factory


def response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
response = Response(b"task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)

return app


def file_response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
with NamedTemporaryFile("wb+") as f:
f.write(b"task initiated")
f.seek(0)
response = FileResponse(f.name, media_type="text/plain", background=task)
await response(scope, receive, send)

return app


def streaming_response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
async def stream() -> AsyncIterable[bytes]:
yield b"task initiated"

response = StreamingResponse(stream(), media_type="text/plain", background=task)
await response(scope, receive, send)

return app
Comment on lines +15 to +59
Copy link
Member Author

@adriangb adriangb Jul 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all just to get 100% code coverage with both code paths. If we removed the old code path (i.e. backgroundTask does not work without the middleware) there would be (almost) no changes to these tests and it would look a lot like https://github.com/xpresso-devs/asgi-background



@pytest.mark.parametrize(
"app_factory",
[
response_app_factory,
streaming_response_app_factory,
file_response_app_factory,
],
)
def test_async_task(
test_client_factory_mw: TestClientFactory,
app_factory: Callable[[BackgroundTask], ASGIApp],
):
TASK_COMPLETE = False

async def async_task():
Expand All @@ -16,17 +79,15 @@ async def async_task():

task = BackgroundTask(async_task)

async def app(scope, receive, send):
response = Response("task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)
app = app_factory(task)

client = test_client_factory(app)
client = test_client_factory_mw(app)
response = client.get("/")
assert response.text == "task initiated"
assert TASK_COMPLETE


def test_sync_task(test_client_factory):
def test_sync_task(test_client_factory: TestClientFactory):
TASK_COMPLETE = False

def sync_task():
Expand All @@ -45,7 +106,7 @@ async def app(scope, receive, send):
assert TASK_COMPLETE


def test_multiple_tasks(test_client_factory: Callable[..., TestClient]):
def test_multiple_tasks(test_client_factory: TestClientFactory):
TASK_COUNTER = 0

def increment(amount):
Expand All @@ -69,7 +130,7 @@ async def app(scope, receive, send):


def test_multi_tasks_failure_avoids_next_execution(
test_client_factory: Callable[..., TestClient]
test_client_factory: TestClientFactory,
) -> None:
TASK_COUNTER = 0

Expand Down
7 changes: 4 additions & 3 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from starlette import status
from starlette.background import BackgroundTask
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.requests import Request
from starlette.responses import (
FileResponse,
Expand Down Expand Up @@ -113,7 +114,7 @@ async def numbers_for_cleanup(start=1, stop=5):
await response(scope, receive, send)

assert filled_by_bg_task == ""
client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
assert filled_by_bg_task == "6, 7, 8, 9"
Expand All @@ -137,7 +138,7 @@ async def __anext__(self):
response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)

client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
assert response.text == "12345"

Expand Down Expand Up @@ -228,7 +229,7 @@ async def app(scope, receive, send):
await response(scope, receive, send)

assert filled_by_bg_task == ""
client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_200_OK
Expand Down