Skip to content

Commit

Permalink
Shield bg task in Response
Browse files Browse the repository at this point in the history
  • Loading branch information
kigawas committed Jun 22, 2022
1 parent 3785217 commit 02aa6eb
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
11 changes: 4 additions & 7 deletions starlette/background.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys
import typing

import anyio

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
Expand All @@ -24,11 +22,10 @@ def __init__(
self.is_async = is_async_callable(func)

async def __call__(self) -> None:
with anyio.CancelScope(shield=True):
if self.is_async:
await self.func(*self.args, **self.kwargs)
else:
await run_in_threadpool(self.func, *self.args, **self.kwargs)
if self.is_async:
await self.func(*self.args, **self.kwargs)
else:
await run_in_threadpool(self.func, *self.args, **self.kwargs)


class BackgroundTasks(BackgroundTask):
Expand Down
9 changes: 2 additions & 7 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def coro() -> None:

async with send_stream:
try:
# may block in send if body_stream not consumed
# `send_stream.send` blocks until `body_stream` is called
await self.app(scope, request.receive, send_stream.send)
except Exception as exc:
app_exc = exc
Expand Down Expand Up @@ -72,12 +72,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)

t = anyio.get_current_task()
if t.name == "anyio.from_thread.BlockingPortal._call_func":
# cancel stuck task due to discarded response
# see: https://github.com/encode/starlette/issues/1022
task_group.cancel_scope.cancel()
task_group.cancel_scope.cancel()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
16 changes: 10 additions & 6 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
)
await send({"type": "http.response.body", "body": self.body})

if self.background is not None:
await self.background()
with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()


class HTMLResponse(Response):
Expand Down Expand Up @@ -264,8 +265,9 @@ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))

if self.background is not None:
await self.background()
with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()


class FileResponse(Response):
Expand Down Expand Up @@ -350,5 +352,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"more_body": more_body,
}
)
if self.background is not None:
await self.background()

with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()
8 changes: 8 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,18 @@ async def _sleep(identifier, delay):
await anyio.sleep(delay)
print(identifier, "completed")

def _sleep_sync(identifier, delay):
import time

print(identifier, "started")
time.sleep(delay)
print(identifier, "completed")

async def bg_task(request):
background_tasks = BackgroundTasks()
background_tasks.add_task(_sleep, "background task 1", 2)
background_tasks.add_task(_sleep, "background task 2", 2)
background_tasks.add_task(_sleep_sync, "background task sync", 2)
return Response(background=background_tasks)

app = Starlette(
Expand Down

0 comments on commit 02aa6eb

Please sign in to comment.