Skip to content

Commit

Permalink
WIP partial revert of disconnect
Browse files Browse the repository at this point in the history
Partial revert of this commit:

040d8c8 Replace task cancellation in `BaseHTTPMiddleware` with `http.disconnect`+`recv_stream.close` (encode#1715)
  • Loading branch information
gctucker committed Jun 12, 2024
1 parent 554f368 commit 01ed47d
Showing 1 changed file with 2 additions and 34 deletions.
36 changes: 2 additions & 34 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,54 +105,22 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
app_exc: Exception | None = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(wrapped_receive)

if response_sent.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
await self.app(scope, request.receive, send_stream.send)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand Down Expand Up @@ -190,7 +158,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
task_group.cancel_scope.cancel()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down

0 comments on commit 01ed47d

Please sign in to comment.