Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 'I/O operation on closed file' and 'Form data has been processed already' upon redirect on multipart data #9201

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGES/9201.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `I/O operation on closed file` and `Form data has been processed already` upon redirect on multipart data -- by :user:`GLGDLY`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ Franek Magiera
Frederik Gladhorn
Frederik Peter Aalund
Gabriel Tremblay
Gary Leung
Gary Wilson Jr.
Gennady Andreyev
Georges Dubus
Expand Down
4 changes: 4 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ClientResponse,
Fingerprint,
RequestInfo,
process_data_to_payload,
)
from .client_ws import (
DEFAULT_WS_CLIENT_TIMEOUT,
Expand Down Expand Up @@ -521,6 +522,9 @@ async def _request(
for trace in traces:
await trace.send_request_start(method, url.update_query(params), headers)

# preprocess the data so we can reuse the Payload object when redirect is needed
data = process_data_to_payload(data)

timer = tm.timer()
try:
with timer:
Expand Down
15 changes: 15 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ class ConnectionKey:
proxy_headers_hash: Optional[int] # hash(CIMultiDict)


def process_data_to_payload(body: Any) -> Any:
# this function is used to convert data to payload before looping into redirects,
# so payload with io objects can be keep alive and use the stored data for the next request
if body is None:
return None

if isinstance(body, FormData):
body = body()

with contextlib.suppress(payload.LookupError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: benchmark this in case the non exceptional case is to raise an exception

body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)

return body


class ClientRequest:
GET_METHODS = {
hdrs.METH_GET,
Expand Down
5 changes: 1 addition & 4 deletions aiohttp/formdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary)
self._fields: List[Any] = []
self._is_multipart = False
self._is_processed = False
self._quote_fields = quote_fields
self._charset = charset

Expand Down Expand Up @@ -117,8 +116,6 @@ def _gen_form_urlencoded(self) -> payload.BytesPayload:

def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
if self._is_processed:
raise RuntimeError("Form data has been processed already")
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
Expand Down Expand Up @@ -149,7 +146,7 @@ def _gen_form_data(self) -> multipart.MultipartWriter:

self._writer.append_payload(part)

self._is_processed = True
self._fields.clear()
return self._writer

def __call__(self) -> Payload:
Expand Down
79 changes: 55 additions & 24 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,36 @@ def __init__(
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(disposition, filename=self._filename)

self._writable = True

try:
self._seekable = self._value.seekable()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think seekable can be blocking , at least aiofiles delegates it the executor https://pypi.org/project/aiofiles/

except AttributeError: # https://github.com/python/cpython/issues/124293
self._seekable = False

if self._seekable:
self._stream_pos = self._value.tell()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: check if tell can be blocking

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure tell() is blocking as aiofiles delegates it to the executor as well https://pypi.org/project/aiofiles/

else:
self._stream_pos = 0

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: see if executor jobs can be combined

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be combined into a single executor job. Example

diff --git a/aiohttp/payload.py b/aiohttp/payload.py
index 395c44f6e..3b222b93e 100644
--- a/aiohttp/payload.py
+++ b/aiohttp/payload.py
@@ -319,15 +319,19 @@ class IOBasePayload(Payload):
         else:
             self._stream_pos = 0
 
-    async def write(self, writer: AbstractStreamWriter) -> None:
-        loop = asyncio.get_event_loop()
+    def _read_first(self) -> None:
+        """Read the first chunk of data from the stream."""
         if self._seekable:
-            await loop.run_in_executor(None, self._value.seek, self._stream_pos)
+            self._value.seek(self._stream_pos)
         elif not self._writable:
             raise RuntimeError(
                 f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
             )
-        chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        return self._value.read(2**16)
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        loop = asyncio.get_running_loop()
+        chunk = await loop.run_in_executor(None, self._read_first)
         while chunk:
             await writer.write(chunk)
             chunk = await loop.run_in_executor(None, self._value.read, 2**16)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to seek here if it's the first time? That seems like it will be the common case

await loop.run_in_executor(None, self._value.seek, self._stream_pos)
elif not self._writable:
raise RuntimeError(
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify close will still always happen

if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)
return "".join(r.decode(encoding, errors) for r in self._value.readlines())


Expand Down Expand Up @@ -354,40 +373,50 @@ def __init__(
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
return os.fstat(self._value.fileno()).st_size - self._stream_pos
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this doesn't run in the event loop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like multipart will call this in the event loop from append_payload via ClientRequest.update_body_from_data

except OSError:
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to seek on the first time?

self._value.seek(self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this isn't called in the event loop as it does block

return self._value.read()

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to seek here if it's the first time?

await loop.run_in_executor(None, self._value.seek, self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: see if executor jobs can be combined

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jobs can be combined like #9201 (comment)

elif not self._writable:
raise RuntimeError(
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify close still happens in failure

if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once


class BytesIOPayload(IOBasePayload):
_value: io.BytesIO

@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position
def size(self) -> Optional[int]:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this doesn't run in the event loop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end = self._value.seek(0, os.SEEK_END)
self._value.seek(self._stream_pos)
return end - self._stream_pos
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: make sure this is run in the executor

self._value.seek(self._stream_pos)
return self._value.read().decode(encoding, errors)


Expand All @@ -397,7 +426,7 @@ class BufferedReaderPayload(IOBasePayload):
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: make sure this is run in the executor

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return os.fstat(self._value.fileno()).st_size - self._stream_pos
except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
Expand All @@ -406,6 +435,8 @@ def size(self) -> Optional[int]:
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this does not run in the event loop

self._value.seek(self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to happen on the first attempt?

return self._value.read().decode(encoding, errors)


Expand Down
10 changes: 6 additions & 4 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,8 @@ async def test_GET_DEFLATE(
aiohttp_client: AiohttpClient, data: Optional[bytes]
) -> None:
async def handler(request: web.Request) -> web.Response:
recv_data = await request.read()
assert recv_data == b"" # both cases should receive empty bytes
return web.json_response({"ok": True})

write_mock = None
Expand All @@ -1553,10 +1555,10 @@ async def write_bytes(
self: ClientRequest, writer: StreamWriter, conn: Connection
) -> None:
nonlocal write_mock
original_write = writer._write
original_write = writer.write

with mock.patch.object(
writer, "_write", autospec=True, spec_set=True, side_effect=original_write
writer, "write", autospec=True, spec_set=True, side_effect=original_write
) as write_mock:
await original_write_bytes(self, writer, conn)

Expand All @@ -1571,8 +1573,8 @@ async def write_bytes(
assert content == {"ok": True}

assert write_mock is not None
# No chunks should have been sent for an empty body.
write_mock.assert_not_called()
# Empty b"" should have been sent for an empty body.
write_mock.assert_called_once_with(b"")


async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
Expand Down
Loading
Loading