diff --git a/cvat-sdk/cvat_sdk/core/downloading.py b/cvat-sdk/cvat_sdk/core/downloading.py index 7caf7138d579..7831b63fa5e6 100644 --- a/cvat-sdk/cvat_sdk/core/downloading.py +++ b/cvat-sdk/cvat_sdk/core/downloading.py @@ -5,13 +5,13 @@ from __future__ import annotations -import os import os.path as osp from contextlib import closing from typing import TYPE_CHECKING, Any, Dict, Optional from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.core.progress import ProgressReporter +from cvat_sdk.core.utils import atomic_writer if TYPE_CHECKING: from cvat_sdk.core.client import Client @@ -41,10 +41,6 @@ def download_file( assert not osp.exists(output_path) - tmp_path = output_path + ".tmp" - if osp.exists(tmp_path): - raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists") - response = self._client.api_client.rest_client.GET( url, _request_timeout=timeout, @@ -57,25 +53,19 @@ def download_file( except ValueError: file_size = None - try: - with open(tmp_path, "wb") as fd: - if pbar is not None: - pbar.start(file_size, desc="Downloading") + with atomic_writer(output_path, "wb") as fd: + if pbar is not None: + pbar.start(file_size, desc="Downloading") - while True: - chunk = response.read(amt=CHUNK_SIZE, decode_content=False) - if not chunk: - break + while True: + chunk = response.read(amt=CHUNK_SIZE, decode_content=False) + if not chunk: + break - if pbar is not None: - pbar.advance(len(chunk)) - - fd.write(chunk) + if pbar is not None: + pbar.advance(len(chunk)) - os.rename(tmp_path, output_path) - except: - os.unlink(tmp_path) - raise + fd.write(chunk) def prepare_and_download_file_from_endpoint( self, diff --git a/cvat-sdk/cvat_sdk/core/utils.py b/cvat-sdk/cvat_sdk/core/utils.py index 407b6d3e79c4..e7c28e90e9f9 100644 --- a/cvat-sdk/cvat_sdk/core/utils.py +++ b/cvat-sdk/cvat_sdk/core/utils.py @@ -4,10 +4,79 @@ from __future__ import annotations -from typing import Any, Dict, Sequence +import contextlib +import itertools +import os 
from typing import (
    IO,
    Any,
    BinaryIO,
    ContextManager,
    Dict,
    Iterator,
    Optional,
    Sequence,
    TextIO,
    Union,
    overload,
)

try:
    # Literal is part of the standard library since Python 3.8.
    from typing import Literal
except ImportError:  # older interpreters: fall back to the backport
    from typing_extensions import Literal


def filter_dict(
    d: Dict[str, Any],
    *,
    keep: Optional[Sequence[str]] = None,
    drop: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Returns a copy of `d` restricted by key.

    An item is retained when its key is listed in `keep` (if `keep` is non-empty)
    and is not listed in `drop` (if `drop` is non-empty). An empty or missing
    `keep` / `drop` imposes no restriction.
    """

    return {k: v for k, v in d.items() if (not keep or k in keep) and (not drop or k not in drop)}


@overload
def atomic_writer(path: Union[os.PathLike, str], mode: Literal["wb"]) -> ContextManager[BinaryIO]:
    ...


@overload
def atomic_writer(
    path: Union[os.PathLike, str], mode: Literal["w"], encoding: str = "UTF-8"
) -> ContextManager[TextIO]:
    ...


@contextlib.contextmanager
def atomic_writer(
    path: Union[os.PathLike, str], mode: Literal["w", "wb"], encoding: str = "UTF-8"
) -> Iterator[IO]:
    """
    Returns a context manager that, when entered, returns a handle to a temporary
    file opened with the specified `mode` and `encoding`. If the context manager
    is exited via an exception, the temporary file is deleted. If the context manager
    is exited normally, the file is moved to `path`, replacing any existing file.

    In other words, this function works like `open()`, but the file does not appear
    at the specified path until and unless the context manager is exited
    normally.

    The temporary file is created next to `path` as `<path>.tmp<N>`, where `N` is
    the first non-negative integer for which no such file exists yet.
    `encoding` is only used when `mode` is "w".

    Raises `ValueError` (on entering the context) if `mode` is neither "w" nor "wb".
    """

    # Reject bad modes up front, before any temporary path is probed.
    if mode not in ("w", "wb"):
        raise ValueError(f"Unsupported mode: {mode!r}")

    path_str = os.fspath(path)

    tmp_file: IO
    for counter in itertools.count():
        tmp_path = f"{path_str}.tmp{counter}"

        try:
            # "x" makes creation exclusive, so an existing file is never clobbered.
            if mode == "w":
                tmp_file = open(tmp_path, "xt", encoding=encoding)
            else:
                tmp_file = open(tmp_path, "xb")
        except FileExistsError:
            continue  # name already taken; try the next counter value

        break

    try:
        with tmp_file:
            yield tmp_file

        # Unlike os.rename, os.replace overwrites an existing destination on
        # every platform (os.rename fails on Windows in that case).
        os.replace(tmp_path, path_str)
    except BaseException:
        # Cleanup is best-effort: a failure to remove the temporary file
        # (e.g. it is already gone) must not mask the original error.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
        raise