Skip to content

Commit

Permalink
Restore 304 performance after fixing FileResponse replace race (#10113)
Browse files Browse the repository at this point in the history

(cherry picked from commit 0130213)
  • Loading branch information
bdraco authored and patchback[bot] committed Dec 5, 2024
1 parent cd0c2c8 commit 435b7c8
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 72 deletions.
1 change: 1 addition & 0 deletions CHANGES/10113.bugfix.rst
163 changes: 91 additions & 72 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import sys
from contextlib import suppress
from enum import Enum, auto
from mimetypes import MimeTypes
from stat import S_ISREG
from types import MappingProxyType
Expand Down Expand Up @@ -69,6 +70,16 @@
}
)


class _FileResponseResult(Enum):
    """The result of the file response.

    Returned by ``_make_response`` so that ``prepare`` can pick which kind
    of response to send.
    """

    # A regular file that was opened and should be sent.
    SEND_FILE = auto()
    # Not a regular file (e.g. a socket); ``prepare`` answers this with the
    # HTTPForbidden status code.
    NOT_ACCEPTABLE = auto()
    # The If-Match or If-Unmodified-Since precondition failed.
    PRE_CONDITION_FAILED = auto()
    # 304 Not Modified: If-None-Match matched the ETag, or the file has not
    # been modified since If-Modified-Since.
    NOT_MODIFIED = auto()


# Add custom pairs and clear the encodings map so guess_type ignores them.
CONTENT_TYPES.encodings_map.clear()
for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
Expand Down Expand Up @@ -166,17 +177,65 @@ async def _precondition_failed(
self.content_length = 0
return await super().prepare(request)

def _open_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]:
"""Return the io object, stat result, and encoding.
def _make_response(
    self, request: "BaseRequest", accept_encoding: str
) -> Tuple[
    _FileResponseResult, Optional[io.BufferedReader], os.stat_result, Optional[str]
]:
    """Decide how to answer the request and open the file if it must be sent.

    Returns a ``(result, file object, stat result, encoding)`` tuple. The
    file object is only set when the result is ``SEND_FILE``. If an
    uncompressed file is selected, the encoding is :py:data:`None`.

    This method should be called from a thread executor since it stats
    and opens files, which may block.
    """
    path, st, encoding = self._get_file_path_stat_encoding(accept_encoding)
    if not path:
        return _FileResponseResult.NOT_ACCEPTABLE, None, st, None

    etag = f"{st.st_mtime_ns:x}-{st.st_size:x}"

    # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
    if_match = request.if_match
    if if_match is not None and not self._etag_match(etag, if_match, weak=False):
        return _FileResponseResult.PRE_CONDITION_FAILED, None, st, encoding

    # If-Unmodified-Since is only honoured when If-Match is absent.
    unmodified_since = request.if_unmodified_since
    if (
        unmodified_since is not None
        and if_match is None
        and st.st_mtime > unmodified_since.timestamp()
    ):
        return _FileResponseResult.PRE_CONDITION_FAILED, None, st, encoding

    # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
    if_none_match = request.if_none_match
    if if_none_match is not None and self._etag_match(
        etag, if_none_match, weak=True
    ):
        return _FileResponseResult.NOT_MODIFIED, None, st, encoding

    # If-Modified-Since is only honoured when If-None-Match is absent.
    modified_since = request.if_modified_since
    if (
        modified_since is not None
        and if_none_match is None
        and st.st_mtime <= modified_since.timestamp()
    ):
        return _FileResponseResult.NOT_MODIFIED, None, st, encoding

    # Only now is the file opened: every early return above avoids the
    # open() entirely.
    fobj = path.open("rb")
    try:
        # Re-stat through the open descriptor so the stat result describes
        # the file that was actually opened, in case it changed between
        # the first stat() and the open().
        st = os.stat(fobj.fileno())
    except OSError:
        # fstat() may not be available on all platforms; keep the
        # earlier stat result in that case.
        pass
    return _FileResponseResult.SEND_FILE, fobj, st, encoding

def _get_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]:
file_path = self._path
for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
if file_encoding not in accept_encoding:
Expand All @@ -187,36 +246,22 @@ def _open_file_path_stat_encoding(
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
fobj = compressed_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, file_encoding
return compressed_path, st, file_encoding

# Fallback to the uncompressed file
st = file_path.stat()
if not S_ISREG(st.st_mode):
return None, st, None
fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, None
return file_path, st, None

async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
fobj, st, file_encoding = await loop.run_in_executor(
None, self._open_file_path_stat_encoding, accept_encoding
response_result, fobj, st, file_encoding = await loop.run_in_executor(
None, self._make_response, request, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
Expand All @@ -227,24 +272,32 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)

try:
# Forbid special files like sockets, pipes, devices, etc.
if not fobj or not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
# Forbid special files like sockets, pipes, devices, etc.
if response_result is _FileResponseResult.NOT_ACCEPTABLE:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

if response_result is _FileResponseResult.PRE_CONDITION_FAILED:
return await self._precondition_failed(request)

if response_result is _FileResponseResult.NOT_MODIFIED:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime
return await self._not_modified(request, etag_value, last_modified)

assert fobj is not None
try:
return await self._prepare_open_file(request, fobj, st, file_encoding)
finally:
if fobj:
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)

async def _prepare_open_file(
self,
Expand All @@ -253,43 +306,9 @@ async def _prepare_open_file(
st: os.stat_result,
file_encoding: Optional[str],
) -> Optional[AbstractStreamWriter]:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
if ifmatch is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return await self._precondition_failed(request)

unmodsince = request.if_unmodified_since
if (
unmodsince is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return await self._precondition_failed(request)

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
if ifnonematch is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return await self._not_modified(request, etag_value, last_modified)

modsince = request.if_modified_since
if (
modsince is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return await self._not_modified(request, etag_value, last_modified)

status = self._status
file_size = st.st_size
count = file_size

start = None

ifrange = request.if_range
Expand Down Expand Up @@ -378,7 +397,7 @@ async def _prepare_open_file(
# compress.
self._compression = False

self.etag = etag_value # type: ignore[assignment]
self.etag = f"{st.st_mtime_ns:x}-{st.st_size:x}" # type: ignore[assignment]
self.last_modified = st.st_mtime # type: ignore[assignment]
self.content_length = count

Expand Down

0 comments on commit 435b7c8

Please sign in to comment.