Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise on File.stream's aenter if the file is a dir or doesn't exist #1046

Merged
merged 10 commits into from
Mar 3, 2022
3 changes: 3 additions & 0 deletions changes/1046.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The async context manager returned by `File.stream` now errors on enter if the target file doesn't exist to improve error handling when a file that doesn't exist is sent as an attachment.

The multiprocessing file reader strategy now expands user relative (`~`) links (like the threaded strategy).
96 changes: 66 additions & 30 deletions hikari/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import mimetypes
import os
import pathlib
import stat
import typing
import urllib.parse
import urllib.request
Expand Down Expand Up @@ -394,7 +395,7 @@ def __exit__(self, exc_type: typing.Type[Exception], exc_val: Exception, exc_tb:

@attr.define(weakref_slot=False)
@typing.final
class _NoOpAsyncReaderContextManagerImpl(typing.Generic[ReaderImplT], AsyncReaderContextManager[ReaderImplT]):
class _NoOpAsyncReaderContextManagerImpl(AsyncReaderContextManager[ReaderImplT]):
impl: ReaderImplT = attr.field()

async def __aenter__(self) -> ReaderImplT:
Expand Down Expand Up @@ -498,6 +499,8 @@ def stream(
-------
AsyncReaderContextManager[AsyncReader]
An async iterable of bytes to stream.

This will error on enter if the target resource doesn't exist.
"""

def __str__(self) -> str:
Expand Down Expand Up @@ -768,6 +771,54 @@ class FileReader(AsyncReader, abc.ABC):
"""The path to the resource to read."""


def _stat(path: pathlib.Path) -> os.stat_result:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.
davfsa marked this conversation as resolved.
Show resolved Hide resolved

return path.stat()


@attr.define(weakref_slot=False)
@typing.final
class _FileAsyncReaderContextManagerImpl(AsyncReaderContextManager[FileReader]):
impl: FileReader = attr.field()

async def __aenter__(self) -> FileReader:
loop = asyncio.get_running_loop()

# Will raise FileNotFoundError if the file doesn't exist (unlike is_dir),
# which is what we want here.
file_stats = await loop.run_in_executor(self.impl.executor, _stat, self.impl.path)

if stat.S_ISDIR(file_stats.st_mode):
raise IsADirectoryError(self.impl.path)

return self.impl

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc: typing.Optional[BaseException],
exc_tb: typing.Optional[types.TracebackType],
) -> None:
pass


def _open_file(path: pathlib.Path) -> typing.BinaryIO:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.
davfsa marked this conversation as resolved.
Show resolved Hide resolved

return path.open("rb")


@attr.define(weakref_slot=False)
class ThreadedFileReader(FileReader):
"""Asynchronous file reader that reads a resource from local storage.
Expand All @@ -780,40 +831,29 @@ class ThreadedFileReader(FileReader):
async def __aiter__(self) -> typing.AsyncGenerator[typing.Any, bytes]:
loop = asyncio.get_running_loop()

path = self.path
if isinstance(path, pathlib.Path):
path = await loop.run_in_executor(self.executor, self._expand, self.path)

fp = await loop.run_in_executor(self.executor, self._open, path)
fp = await loop.run_in_executor(self.executor, _open_file, self.path)

try:
while True:
chunk = await loop.run_in_executor(self.executor, self._read_chunk, fp, _MAGIC)
chunk = await loop.run_in_executor(self.executor, fp.read, _MAGIC)
yield chunk
if len(chunk) < _MAGIC:
break

finally:
await loop.run_in_executor(self.executor, self._close, fp)

@staticmethod
def _expand(path: pathlib.Path) -> pathlib.Path:
# .expanduser is Platform dependent. Will expand stuff like ~ to /home/<user> on posix.
# .resolve will follow symlinks and what-have-we to translate stuff like `..` to proper paths.
return path.expanduser().resolve()
await loop.run_in_executor(self.executor, fp.close)

@staticmethod
@typing.final
def _read_chunk(fp: typing.IO[bytes], n: int = 10_000) -> bytes:
return fp.read(n)

@staticmethod
def _open(path: Pathish) -> typing.IO[bytes]:
return open(path, "rb")
def _read_all(path: pathlib.Path) -> bytes:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.

@staticmethod
def _close(fp: typing.IO[bytes]) -> None:
fp.close()
with path.open("rb") as fp:
return fp.read()


@attr.define(slots=False, weakref_slot=False) # Do not slot (pickle)
Expand All @@ -827,7 +867,7 @@ class MultiprocessingFileReader(FileReader):
"""

async def __aiter__(self) -> typing.AsyncGenerator[typing.Any, bytes]:
yield await asyncio.get_running_loop().run_in_executor(self.executor, self._read_all)
yield await asyncio.get_running_loop().run_in_executor(self.executor, _read_all, self.path)

def __getstate__(self) -> typing.Dict[str, typing.Any]:
return {"path": self.path, "filename": self.filename}
Expand All @@ -838,10 +878,6 @@ def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None:
self.executor: typing.Optional[concurrent.futures.Executor] = None
self.mimetype: typing.Optional[str] = None

def _read_all(self) -> bytes:
with open(self.path, "rb") as fp:
return fp.read()


class File(Resource[FileReader]):
"""A resource that exists on the local machine's storage to be uploaded.
Expand Down Expand Up @@ -921,7 +957,7 @@ def stream(
# so this is safe enough to do.:
is_threaded = executor is None or isinstance(executor, concurrent.futures.ThreadPoolExecutor)
impl = ThreadedFileReader if is_threaded else MultiprocessingFileReader
return _NoOpAsyncReaderContextManagerImpl(impl(self.filename, None, executor, self.path))
return _FileAsyncReaderContextManagerImpl(impl(self.filename, None, executor, self.path))


########################################################################
Expand Down
77 changes: 77 additions & 0 deletions tests/hikari/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import base64
import concurrent.futures
import contextlib
import pathlib
import random
import tempfile
import typing

import mock
import pytest

from hikari import files
Expand All @@ -41,3 +50,71 @@ def test___exit__(self, reader):
reader().__exit__(None, None, None)
except AttributeError as exc:
pytest.fail(exc)


class Test_FileAsyncReaderContextManagerImpl:
@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
@pytest.mark.asyncio()
async def test_context_manager(self, executor: typing.Callable[[], concurrent.futures.Executor]):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)

with tempfile.NamedTemporaryFile() as file:
mock_reader.path = pathlib.Path(file.name)

async with context_manager as reader:
assert reader is mock_reader

@pytest.mark.asyncio()
async def test_context_manager_when_expandname_raises_runtime_error(self):
# We can't mock patch stuff in other processes easily (if at all) so
# for this test we only run it threaded.
mock_reader = mock.Mock(executor=concurrent.futures.ThreadPoolExecutor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)

stack = contextlib.ExitStack()
file = stack.enter_context(tempfile.NamedTemporaryFile())
expandname = stack.enter_context(mock.patch.object(pathlib.Path, "expanduser", side_effect=RuntimeError))

with file:
mock_reader.path = pathlib.Path(file.name)

async with context_manager as reader:
assert reader is mock_reader

expandname.assert_called_once_with()

@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
@pytest.mark.asyncio()
async def test_context_manager_for_unknown_file(self, executor: typing.Callable[[], concurrent.futures.Executor]):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)

mock_reader.path = pathlib.Path(
base64.urlsafe_b64encode(random.getrandbits(512).to_bytes(64, "little")).decode()
)

with pytest.raises(FileNotFoundError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
...

@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
@pytest.mark.asyncio()
async def test_test_context_manager_when_target_is_dir(
self, executor: typing.Callable[[], concurrent.futures.Executor]
):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)

with tempfile.TemporaryDirectory() as name:
mock_reader.path = pathlib.Path(name)

with pytest.raises(IsADirectoryError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
...