Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow using Request.form() as a context manager #1903

Merged
merged 11 commits into from
Feb 6, 2023
10 changes: 5 additions & 5 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ There are a few different interfaces for returning the body of the request:

The request body as bytes: `await request.body()`

The request body, parsed as form data or multipart: `await request.form()`
The request body, parsed as form data or multipart: `async with request.form() as form:`

The request body, parsed as JSON: `await request.json()`

Expand Down Expand Up @@ -114,7 +114,7 @@ state with `disconnected = await request.is_disconnected()`.

Request files are normally sent as multipart form data (`multipart/form-data`).

When you call `await request.form()` you receive a `starlette.datastructures.FormData` which is an immutable
When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable
multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`.

`UploadFile` has the following attributes:
Expand All @@ -137,9 +137,9 @@ As all these methods are `async` methods, you need to "await" them.
For example, you can get the file name and the contents with:

```python
form = await request.form()
filename = form["upload_file"].filename
contents = await form["upload_file"].read()
async with request.form() as form:
filename = form["upload_file"].filename
contents = await form["upload_file"].read()
```

!!! info
Expand Down
57 changes: 57 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import asyncio
import functools
import sys
import typing
from types import TracebackType

if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
else: # pragma: no cover
from typing import Protocol


def is_async_callable(obj: typing.Any) -> bool:
Expand All @@ -10,3 +17,53 @@ def is_async_callable(obj: typing.Any) -> bool:
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)


T_co = typing.TypeVar("T_co", covariant=True)


# TODO: once 3.8 is the minimum supported version (27 Jun 2023)
# this can just become
# class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co]): pass
adriangb marked this conversation as resolved.
Show resolved Hide resolved
class AwaitableOrContextManager(Protocol[T_co]):
def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
... # pragma: no cover

async def __aenter__(self) -> T_co:
... # pragma: no cover

async def __aexit__(
self,
__exc_type: typing.Optional[typing.Type[BaseException]],
__exc_value: typing.Optional[BaseException],
__traceback: typing.Optional[TracebackType],
) -> typing.Union[bool, None]:
... # pragma: no cover
Kludex marked this conversation as resolved.
Show resolved Hide resolved
adriangb marked this conversation as resolved.
Show resolved Hide resolved


class SupportsAsyncClose(Protocol):
async def close(self) -> None:
... # pragma: no cover


SupportsAsyncCloseType = typing.TypeVar(
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
)


class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")

def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw

def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()

async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered

async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
await self.entered.close()
return None
15 changes: 10 additions & 5 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import anyio

from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
Expand Down Expand Up @@ -187,6 +188,8 @@ async def empty_send(message: Message) -> typing.NoReturn:


class Request(HTTPConnection):
_form: typing.Optional[FormData]

def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
Expand All @@ -196,6 +199,7 @@ def __init__(
self._send = send
self._stream_consumed = False
self._is_disconnected = False
self._form = None

@property
def method(self) -> str:
Expand All @@ -210,10 +214,8 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
yield self._body
yield b""
return

if self._stream_consumed:
raise RuntimeError("Stream consumed")

self._stream_consumed = True
while True:
message = await self._receive()
Expand Down Expand Up @@ -242,8 +244,8 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def form(self) -> FormData:
if not hasattr(self, "_form"):
async def _get_form(self) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
Expand All @@ -265,8 +267,11 @@ async def form(self) -> FormData:
self._form = FormData()
return self._form

def form(self) -> AwaitableOrContextManager[FormData]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be async?

Suggested change
def form(self) -> AwaitableOrContextManager[FormData]:
async def form(self) -> AwaitableOrContextManager[FormData]:

I understand it would work either way just because the awaitable thing is returned, but maybe it could help to make it async, to make it explicit in the function that it returns something to be awaited. What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I don't think that will work: we want to pass an await able object to AwaitableOrContextManagerWrapper so that it can be used like async with request.form() and not async with await request.form().

return AwaitableOrContextManagerWrapper(self._get_form())

async def close(self) -> None:
if hasattr(self, "_form"):
if self._form is not None:
await self._form.close()

async def is_disconnected(self) -> bool:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ async def app(scope, receive, send):
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_context_manager(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)

client = test_client_factory(app)

response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_body_then_stream(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down