Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow using Request.form() as a context manager #1903

Merged
merged 11 commits into from
Feb 6, 2023
54 changes: 54 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import asyncio
import functools
import sys
import typing
from types import TracebackType

if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
else: # pragma: no cover
from typing import Protocol


def is_async_callable(obj: typing.Any) -> bool:
Expand All @@ -10,3 +17,50 @@ def is_async_callable(obj: typing.Any) -> bool:
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)


T_co = typing.TypeVar("T_co", covariant=True)


class AwaitableOrContextManager(Protocol[T_co]):
def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
... # pragma: no cover

async def __aenter__(self) -> T_co:
... # pragma: no cover

async def __aexit__(
self,
__exc_type: typing.Optional[typing.Type[BaseException]],
__exc_value: typing.Optional[BaseException],
__traceback: typing.Optional[TracebackType],
) -> typing.Union[bool, None]:
... # pragma: no cover
Kludex marked this conversation as resolved.
Show resolved Hide resolved
adriangb marked this conversation as resolved.
Show resolved Hide resolved


class SupportsAsyncClose(Protocol):
async def close(self) -> None:
... # pragma: no cover


SupportsAsyncCloseType = typing.TypeVar(
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
)


class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")

def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw

def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()

async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered

async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
await self.entered.close()
return None
15 changes: 10 additions & 5 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import anyio

from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
Expand Down Expand Up @@ -187,6 +188,8 @@ async def empty_send(message: Message) -> typing.NoReturn:


class Request(HTTPConnection):
_form: typing.Optional[FormData]

def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
Expand All @@ -196,6 +199,7 @@ def __init__(
self._send = send
self._stream_consumed = False
self._is_disconnected = False
self._form = None

@property
def method(self) -> str:
Expand All @@ -210,10 +214,8 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
yield self._body
yield b""
return

if self._stream_consumed:
raise RuntimeError("Stream consumed")

self._stream_consumed = True
while True:
message = await self._receive()
Expand Down Expand Up @@ -242,8 +244,8 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def form(self) -> FormData:
if not hasattr(self, "_form"):
async def _get_form(self) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
Expand All @@ -265,8 +267,11 @@ async def form(self) -> FormData:
self._form = FormData()
return self._form

def form(self) -> AwaitableOrContextManager[FormData]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be async?

Suggested change
def form(self) -> AwaitableOrContextManager[FormData]:
async def form(self) -> AwaitableOrContextManager[FormData]:

I understand it would work either way just because the awaitable thing is returned, but maybe it could help to make it async, to make it explicit in the function that it returns something to be awaited. What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I don't think that will work: we want to pass an await able object to AwaitableOrContextManagerWrapper so that it can be used like async with request.form() and not async with await request.form().

return AwaitableOrContextManagerWrapper(self._get_form())

async def close(self) -> None:
if hasattr(self, "_form"):
if self._form is not None:
await self._form.close()

async def is_disconnected(self) -> bool:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ async def app(scope, receive, send):
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_context_manager(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)

client = test_client_factory(app)

response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_body_then_stream(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down