Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reorganize BaseHTTPMiddleware tasks so that app is in the same task as __call__ #1

Draft
wants to merge 6 commits into
base: base-http-middleware-no-cancellation
Choose a base branch
from
94 changes: 77 additions & 17 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import contextvars
import typing

import anyio

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
Expand All @@ -24,20 +26,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def coro() -> None:
nonlocal app_exc
dispatch_first_phase_ended = anyio.Event()
streams_ready = anyio.Event()
request_for_next: typing.Optional[Request] = None
response_sent = anyio.Event()
app_exc: typing.Optional[Exception] = None
dispatch_context_copy: typing.Optional[contextvars.Context] = None

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
except Exception as exc:
app_exc = exc

task_group.start_soon(coro)
async def call_next(request: Request) -> Response:
nonlocal request_for_next, dispatch_context_copy
request_for_next = request
dispatch_context_copy = contextvars.copy_context()
dispatch_first_phase_ended.set()
await streams_ready.wait()

try:
message = await recv_stream.receive()
Expand Down Expand Up @@ -67,11 +68,70 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
async def process_dispatch(request: Request):
nonlocal dispatch_context_copy
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
dispatch_context_copy = contextvars.copy_context()
dispatch_first_phase_ended.set()
response_sent.set()

async with anyio.create_task_group() as task_group:
task_group.start_soon(process_dispatch, Request(scope, receive=receive))

await dispatch_first_phase_ended.wait()

# Copy contextvars updated from dispatch into the current context.
for context_var, dispatch_context_value in dispatch_context_copy.items():
try:
if context_var.get() is not dispatch_context_value:
context_var.set(dispatch_context_value)
except LookupError:
context_var.set(dispatch_context_value)

if request_for_next is None:
return

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request_for_next.receive)

if response_sent.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return

send_stream, recv_stream = anyio.create_memory_object_stream()
streams_ready.set()

task_group.start_soon(close_recv_stream_on_response_sent)

async with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
Loading