-
-
Notifications
You must be signed in to change notification settings - Fork 948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Shield send "http.response.start" from cancellation #1706
Conversation
`RuntimeError: No response returned.` is raised in BaseHTTPMiddleware if request is disconnected, due to `task_group.cancel_scope.cancel()` in StreamingResponse.__call__.<locals>.wrap and cancellation check in `await checkpoint()` of MemoryObjectSendStream.send. Let's fix this behaviour change caused by anyio integration in 0.15.0.
I think this can be fixed without shielding. This test fails on diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 49a5e3e..5210d2d 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -4,7 +4,7 @@ import anyio
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
@@ -12,6 +12,10 @@ DispatchFunction = typing.Callable[
]
+class _ClientDisconnected(Exception):
+ pass
+
+
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
@@ -28,12 +32,18 @@ class BaseHTTPMiddleware:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
+ async def recv() -> Message:
+ message = await request.receive()
+ if message["type"] == "http.disconnect":
+ raise _ClientDisconnected
+ return message
+
async def coro() -> None:
nonlocal app_exc
async with send_stream:
try:
- await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, recv, send_stream.send)
except Exception as exc:
app_exc = exc
@@ -69,7 +79,10 @@ class BaseHTTPMiddleware:
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
- response = await self.dispatch_func(request, call_next)
+ try:
+ response = await self.dispatch_func(request, call_next)
+ except _ClientDisconnected:
+ return
await response(scope, receive, send)
task_group.cancel_scope.cancel()
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 976d77b..92826bc 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -1,13 +1,17 @@
import contextvars
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Awaitable, Callable
+import anyio
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
class CustomMiddleware(BaseHTTPMiddleware):
@@ -206,3 +210,41 @@ def test_contextvars(test_client_factory, middleware_cls: type):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_client_disconnects_before_response_is_sent() -> None:
+ # test for https://github.com/encode/starlette/issues/1527
+ app: ASGIApp
+
+ async def homepage(request: Request):
+ await anyio.sleep(5)
+ return PlainTextResponse("hi!")
+
+ async def dispatch(
+ request: Request, call_next: Callable[[Request], Awaitable[Response]]
+ ) -> Response:
+ return await call_next(request)
+
+ app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch)
+ app = BaseHTTPMiddleware(app, dispatch=dispatch)
+
+ async def recv_gen() -> AsyncGenerator[Message, None]:
+ yield {"type": "http.request"}
+ yield {"type": "http.disconnect"}
+
+ async def send_gen() -> AsyncGenerator[None, Message]:
+ msg = yield
+ assert msg["type"] == "http.response.start"
+ msg = yield
+ raise AssertionError("Should not be called")
+
+ scope = {"type": "http", "method": "GET", "path": "/"}
+
+ async with AsyncExitStack() as stack:
+ recv = recv_gen()
+ stack.push_async_callback(recv.aclose)
+ send = send_gen()
+ stack.push_async_callback(send.aclose)
+ await send.__anext__()
+ await app(scope, recv.__aiter__().__anext__, send.asend) |
My fix addresses the behaviour change in That behaviour of That test passes on my branch if async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
yield {"type": "http.disconnect"}
+ yield {"type": "http.disconnect"} |
Apologies. I've been looking at I asked in asgiref for confirmation on the expected behavior or ASGI servers w.r.t. sending the disconnect message multiple times. I think it would be a good idea to adapt that test (or just write a new one, up to you) to the specific situation this is supposed to fix. I think a test will be required before merging this. |
I've added a test for the specific situation this is supposed to fix. Actually, I don't think async def send(msg):
with anyio.CancelScope(shield=True):
await send_stream.send(msg) - await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, request.receive, send) I think it is probably preferable to do this in I am also happy to close this PR in favour of the fix that you proposed in |
If that's all that's required in Also the barrier for doing something like this in BaseHTTPMiddleware is a lower: the fix is close the the source of the issue and BaseHTTPMiddleware already is dealing with streams, tasks, cancellation and such so adding some shielding isn't moving the needle too much on complexity. |
Well, it's not actually 1 LOC 😅 I have submitted PR #1710 to shield send "http.response.start" from cancellation in |
Fixes #1634
RuntimeError: No response returned.
is raised inBaseHTTPMiddleware
if request is disconnected, due totask_group.cancel_scope.cancel()
inStreamingResponse.__call__.<locals>.wrap
and cancellation check inawait checkpoint()
ofMemoryObjectSendStream.send
.Let's fix this behaviour change caused by anyio integration in 0.15.0.
I managed to make this error reproducible in 0.14.2 by partially emulating 0.15.0 logic: acjh@37dd8ac
starlette/concurrency.py:
starlette/middleware/base.py: