Skip to content

Commit

Permalink
SDK: add a utility function for atomically writing a file (#5372)
Browse files Browse the repository at this point in the history
  • Loading branch information
SpecLad committed Nov 30, 2022
1 parent 460df33 commit 38193ff
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 22 deletions.
32 changes: 11 additions & 21 deletions cvat-sdk/cvat_sdk/core/downloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from __future__ import annotations

import os
import os.path as osp
from contextlib import closing
from typing import TYPE_CHECKING, Any, Dict, Optional

from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.utils import atomic_writer

if TYPE_CHECKING:
from cvat_sdk.core.client import Client
Expand Down Expand Up @@ -41,10 +41,6 @@ def download_file(

assert not osp.exists(output_path)

tmp_path = output_path + ".tmp"
if osp.exists(tmp_path):
raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists")

response = self._client.api_client.rest_client.GET(
url,
_request_timeout=timeout,
Expand All @@ -57,25 +53,19 @@ def download_file(
except ValueError:
file_size = None

try:
with open(tmp_path, "wb") as fd:
if pbar is not None:
pbar.start(file_size, desc="Downloading")
with atomic_writer(output_path, "wb") as fd:
if pbar is not None:
pbar.start(file_size, desc="Downloading")

while True:
chunk = response.read(amt=CHUNK_SIZE, decode_content=False)
if not chunk:
break
while True:
chunk = response.read(amt=CHUNK_SIZE, decode_content=False)
if not chunk:
break

if pbar is not None:
pbar.advance(len(chunk))

fd.write(chunk)
if pbar is not None:
pbar.advance(len(chunk))

os.rename(tmp_path, output_path)
except:
os.unlink(tmp_path)
raise
fd.write(chunk)

def prepare_and_download_file_from_endpoint(
self,
Expand Down
71 changes: 70 additions & 1 deletion cvat-sdk/cvat_sdk/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,79 @@

from __future__ import annotations

from typing import Any, Dict, Sequence
import contextlib
import itertools
import os
from typing import (
IO,
Any,
BinaryIO,
ContextManager,
Dict,
Iterator,
Sequence,
TextIO,
Union,
overload,
)

from typing_extensions import Literal


def filter_dict(
    d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None
) -> Dict[str, Any]:
    """
    Returns a new dictionary with a subset of the items of `d`.

    If `keep` is non-empty, only keys contained in `keep` are retained.
    If `drop` is non-empty, keys contained in `drop` are removed.
    An empty or omitted `keep`/`drop` imposes no restriction.
    The relative order of the surviving items is the same as in `d`.
    """

    filtered: Dict[str, Any] = {}

    for key, value in d.items():
        # An allow-list is only enforced when it actually has entries.
        if keep and key not in keep:
            continue

        # Likewise, an empty deny-list rejects nothing.
        if drop and key in drop:
            continue

        filtered[key] = value

    return filtered


@overload
def atomic_writer(path: Union[os.PathLike, str], mode: Literal["wb"]) -> ContextManager[BinaryIO]:
    ...


@overload
def atomic_writer(
    path: Union[os.PathLike, str], mode: Literal["w"], encoding: str = "UTF-8"
) -> ContextManager[TextIO]:
    ...


@contextlib.contextmanager
def atomic_writer(
    path: Union[os.PathLike, str], mode: Literal["w", "wb"], encoding: str = "UTF-8"
) -> Iterator[IO]:
    """
    Returns a context manager that, when entered, returns a handle to a temporary
    file opened with the specified `mode` and `encoding`. If the context manager
    is exited via an exception, the temporary file is deleted. If the context manager
    is exited normally, the temporary file is moved to `path`, replacing any file
    that already exists there.

    In other words, this function works like `open()`, but the file does not appear
    at the specified path until and unless the context manager is exited
    normally.

    The temporary file is created next to the destination, under the name
    `<path>.tmp<N>`, where N is the smallest non-negative integer for which
    no such file exists yet.

    Parameters:
        path: the final location of the file.
        mode: "w" for text output or "wb" for binary output.
        encoding: the text encoding; only used when `mode` is "w".

    Raises:
        ValueError: if `mode` is neither "w" nor "wb". This is raised when
            the context manager is entered, before any file is created.
    """

    # Validate the mode up front, so that an unsupported mode is reported
    # before anything touches the file system. Exclusive creation ("x")
    # guarantees we never clobber a temporary file that belongs to someone else.
    if mode == "w":
        open_kwargs = {"mode": "xt", "encoding": encoding}
    elif mode == "wb":
        open_kwargs = {"mode": "xb"}
    else:
        raise ValueError(f"Unsupported mode: {mode!r}")

    path_str = os.fspath(path)

    for counter in itertools.count():
        tmp_path = f"{path_str}.tmp{counter}"

        try:
            tmp_file = open(tmp_path, **open_kwargs)
        except FileExistsError:
            continue  # try next counter value

        break

    try:
        # The file must be closed before it is moved (required on Windows).
        with tmp_file:
            yield tmp_file

        # Unlike os.rename, os.replace overwrites an existing destination
        # on every platform, so the behavior does not depend on the OS.
        os.replace(tmp_path, path_str)
    except BaseException:
        # Best-effort cleanup: a failure to remove the temporary file
        # must not hide the exception that brought us here.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
        raise

0 comments on commit 38193ff

Please sign in to comment.