Skip to content

Commit 5ec1d27

Browse files
pared authored and efiop committed
remote: base: don't checkout existing file (#2358)
* remote: base: don't checkout existing file
* remote: _checkout_file: split to smaller methods
* state: local: verify if file needs checkout considering link type and access rights
* tests: checkout: test avoiding unnecessary copying
* efiop review refactor
* remote: local: refactor _needs_checkout NOTE
* remote: local: remove bool cast for `protected`, wrap linking in try
* test: checkout: test relinking upon existing cache
* stage: remove obsolete unprotecting from Stage.create
* test: checkout: refactor setting link type
* test: checkout: parametrize relinking test
* remote: local: don't remove cache test file on cache_copy test
* remote: base: remove force from _needs_checkout arguments
* Update dvc/remote/local/__init__.py Co-Authored-By: Alexander Schepanovski <suor.web@gmail.com>
* suor review refactor
* remote: base: re-add note why do we unprotect on _checkout_file
* test: checkout: tests for relinking and resetting target protection
* remove checkout behavior change
* remove unused imports
* remote: base: reuse __checkout_file in __checkout_dir
* initial dir
* checkout: test: should relink on repeated add
* remote: cleanup commented lines
* remote: base/local: add save_link flag to _checkout_file
* remote: local: add allow_copy flag to unprotect method
* Update dvc/remote/local/__init__.py Co-Authored-By: Ruslan Kuprieiev <kupruser@gmail.com>
* remote: local: remove unnecessary removal warning on _checkout_file
* remote: local: proper cache link type detection
* test: checkout: move repeated add tests
* Update dvc/remote/local/__init__.py Co-Authored-By: Ruslan Kuprieiev <kupruser@gmail.com>
* remote: base: remove _needs_checkout
* Update dvc/remote/base.py Co-Authored-By: Ruslan Kuprieiev <kupruser@gmail.com>
* remote: base: remove wrappers from _checkout_file
1 parent 74bd7d5 commit 5ec1d27

File tree

3 files changed

+174
-36
lines changed

3 files changed

+174
-36
lines changed

dvc/remote/base.py

Lines changed: 41 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -680,20 +680,45 @@ def safe_remove(self, path_info, force=False):
680680
self.remove(path_info)
681681

682682
def _checkout_file(
683-
self, path_info, checksum, force, progress_callback=None
683+
self,
684+
path_info,
685+
checksum,
686+
force,
687+
progress_callback=None,
688+
save_link=True,
684689
):
685-
cache_info = self.checksum_to_path_info(checksum)
686-
if self.exists(path_info):
687-
msg = "data '{}' exists. Removing before checkout."
688-
logger.warning(msg.format(str(path_info)))
690+
# NOTE: In case if path_info is already cached and path_info's
691+
# link type matches cache link type, we would like to avoid
692+
# relinking.
693+
if self.changed(
694+
path_info, {self.PARAM_CHECKSUM: checksum}
695+
) or not self._link_matches(path_info):
689696
self.safe_remove(path_info, force=force)
690697

691-
self.link(cache_info, path_info)
692-
self.state.save_link(path_info)
693-
self.state.save(path_info, checksum)
698+
cache_info = self.checksum_to_path_info(checksum)
699+
self.link(cache_info, path_info)
700+
701+
if save_link:
702+
self.state.save_link(path_info)
703+
704+
self.state.save(path_info, checksum)
705+
else:
706+
# NOTE: performing (un)protection costs us +/- the same as checking
707+
# if path_info is protected. Instead of implementing logic,
708+
# just (un)protect according to self.protected.
709+
if self.protected:
710+
self.protect(path_info)
711+
else:
712+
# NOTE dont allow copy, because we checked before that link
713+
# type matches cache, and we don't want data duplication
714+
self.unprotect(path_info, allow_copy=False)
715+
694716
if progress_callback:
695717
progress_callback(str(path_info))
696718

719+
def _link_matches(self, path_info):
720+
return True
721+
697722
def makedirs(self, path_info):
698723
raise NotImplementedError
699724

@@ -712,17 +737,14 @@ def _checkout_dir(
712737
for entry in dir_info:
713738
relative_path = entry[self.PARAM_RELPATH]
714739
entry_checksum = entry[self.PARAM_CHECKSUM]
715-
entry_cache_info = self.checksum_to_path_info(entry_checksum)
716740
entry_info = path_info / relative_path
717-
718-
entry_checksum_info = {self.PARAM_CHECKSUM: entry_checksum}
719-
if self.changed(entry_info, entry_checksum_info):
720-
if self.exists(entry_info):
721-
self.safe_remove(entry_info, force=force)
722-
self.link(entry_cache_info, entry_info)
723-
self.state.save(entry_info, entry_checksum)
724-
if progress_callback:
725-
progress_callback(str(entry_info))
741+
self._checkout_file(
742+
entry_info,
743+
entry_checksum,
744+
force,
745+
progress_callback,
746+
save_link=False,
747+
)
726748

727749
self._remove_redundant_files(path_info, dir_info, force)
728750

@@ -803,7 +825,7 @@ def get_files_number(self, checksum):
803825
return 1
804826

805827
@staticmethod
806-
def unprotect(path_info):
828+
def unprotect(path_info, allow_copy=True):
807829
pass
808830

809831
def _get_unpacked_dir_names(self, checksums):

dvc/remote/local/__init__.py

Lines changed: 55 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ def __init__(self, repo, config):
7878
self.cache_types = types
7979
else:
8080
self.cache_types = copy(self.DEFAULT_CACHE_TYPES)
81+
self.cache_type_confirmed = False
8182

8283
# A clunky way to detect cache dir
8384
storagepath = config.get(Config.SECTION_LOCAL_STORAGEPATH, None)
@@ -188,6 +189,7 @@ def _try_links(self, from_info, to_info, link_types):
188189
link_method = self._get_link_method(link_types[0])
189190
try:
190191
self._do_link(from_info, to_info, link_method)
192+
self.cache_type_confirmed = True
191193
return
192194

193195
except DvcException as exc:
@@ -471,19 +473,22 @@ def _log_missing_caches(checksum_info_dict):
471473
logger.warning(msg)
472474

473475
@staticmethod
474-
def _unprotect_file(path):
476+
def _unprotect_file(path, allow_copy=True):
475477
if System.is_symlink(path) or System.is_hardlink(path):
476-
logger.debug("Unprotecting '{}'".format(path))
477-
tmp = os.path.join(os.path.dirname(path), "." + str(uuid()))
478-
479-
# The operations order is important here - if some application
480-
# would access the file during the process of copyfile then it
481-
# would get only the part of file. So, at first, the file should be
482-
# copied with the temporary name, and then original file should be
483-
# replaced by new.
484-
copyfile(path, tmp, name="Unprotecting '{}'".format(relpath(path)))
485-
remove(path)
486-
os.rename(tmp, path)
478+
if allow_copy:
479+
logger.debug("Unprotecting '{}'".format(path))
480+
tmp = os.path.join(os.path.dirname(path), "." + str(uuid()))
481+
482+
# The operations order is important here - if some application
483+
# would access the file during the process of copyfile then it
484+
# would get only the part of file. So, at first, the file
485+
# should be copied with the temporary name, and then
486+
# original file should be replaced by new.
487+
copyfile(
488+
path, tmp, name="Unprotecting '{}'".format(relpath(path))
489+
)
490+
remove(path)
491+
os.rename(tmp, path)
487492

488493
else:
489494
logger.debug(
@@ -493,21 +498,21 @@ def _unprotect_file(path):
493498

494499
os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)
495500

496-
def _unprotect_dir(self, path):
501+
def _unprotect_dir(self, path, allow_copy=True):
497502
for fname in walk_files(path, self.repo.dvcignore):
498-
RemoteLOCAL._unprotect_file(fname)
503+
self._unprotect_file(fname, allow_copy)
499504

500-
def unprotect(self, path_info):
505+
def unprotect(self, path_info, allow_copy=True):
501506
path = path_info.fspath
502507
if not os.path.exists(path):
503508
raise DvcException(
504509
"can't unprotect non-existing data '{}'".format(path)
505510
)
506511

507512
if os.path.isdir(path):
508-
self._unprotect_dir(path)
513+
self._unprotect_dir(path, allow_copy)
509514
else:
510-
RemoteLOCAL._unprotect_file(path)
515+
self._unprotect_file(path, allow_copy)
511516

512517
@staticmethod
513518
def protect(path_info):
@@ -581,3 +586,36 @@ def _get_unpacked_dir_names(self, checksums):
581586
if self.is_dir_checksum(c):
582587
unpacked.add(c + self.UNPACKED_DIR_SUFFIX)
583588
return unpacked
589+
590+
def _get_cache_type(self, path_info):
591+
if self.cache_type_confirmed:
592+
return self.cache_types[0]
593+
594+
workspace_file = path_info.with_name("." + uuid())
595+
test_cache_file = self.path_info / ".cache_type_test_file"
596+
if not self.exists(test_cache_file):
597+
with open(fspath_py35(test_cache_file), "wb") as fobj:
598+
fobj.write(bytes(1))
599+
try:
600+
self.link(test_cache_file, workspace_file)
601+
finally:
602+
self.remove(workspace_file)
603+
self.remove(test_cache_file)
604+
605+
self.cache_type_confirmed = True
606+
return self.cache_types[0]
607+
608+
def _link_matches(self, path_info):
609+
is_hardlink = System.is_hardlink(path_info)
610+
is_symlink = System.is_symlink(path_info)
611+
is_copy_or_reflink = not is_hardlink and not is_symlink
612+
613+
cache_type = self._get_cache_type(path_info)
614+
615+
if cache_type == "symlink":
616+
return is_symlink
617+
618+
if cache_type == "hardlink":
619+
return is_hardlink
620+
621+
return is_copy_or_reflink

tests/func/test_add.py

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -553,3 +553,81 @@ def test_readding_dir_should_not_unprotect_all(dvc_repo, repo_dir):
553553

554554
assert not unprotect_spy.mock.called
555555
assert System.is_symlink(new_file)
556+
557+
558+
def test_should_not_checkout_when_adding_cached_copy(repo_dir, dvc_repo):
559+
dvc_repo.cache.local.cache_types = ["copy"]
560+
561+
dvc_repo.add(repo_dir.FOO)
562+
dvc_repo.add(repo_dir.BAR)
563+
564+
shutil.copy(repo_dir.BAR, repo_dir.FOO)
565+
566+
copy_spy = spy(shutil.copyfile)
567+
568+
RemoteLOCAL.CACHE_TYPE_MAP["copy"] = copy_spy
569+
dvc_repo.add(repo_dir.FOO)
570+
571+
assert copy_spy.mock.call_count == 0
572+
573+
574+
@pytest.mark.parametrize(
575+
"link,new_link,link_test_func",
576+
[
577+
("hardlink", "copy", lambda path: not System.is_hardlink(path)),
578+
("symlink", "copy", lambda path: not System.is_symlink(path)),
579+
("copy", "hardlink", System.is_hardlink),
580+
("copy", "symlink", System.is_symlink),
581+
],
582+
)
583+
def test_should_relink_on_repeated_add(
584+
link, new_link, link_test_func, repo_dir, dvc_repo
585+
):
586+
dvc_repo.cache.local.cache_types = [link]
587+
588+
dvc_repo.add(repo_dir.FOO)
589+
dvc_repo.add(repo_dir.BAR)
590+
591+
os.remove(repo_dir.FOO)
592+
RemoteLOCAL.CACHE_TYPE_MAP[link](repo_dir.BAR, repo_dir.FOO)
593+
594+
dvc_repo.cache.local.cache_types = [new_link]
595+
596+
dvc_repo.add(repo_dir.FOO)
597+
598+
assert link_test_func(repo_dir.FOO)
599+
600+
601+
@pytest.mark.parametrize(
602+
"link, link_func",
603+
[("hardlink", System.hardlink), ("symlink", System.symlink)],
604+
)
605+
def test_should_relink_single_file_in_dir(link, link_func, dvc_repo, repo_dir):
606+
dvc_repo.cache.local.cache_types = [link]
607+
608+
dvc_repo.add(repo_dir.DATA_DIR)
609+
610+
# NOTE status triggers unpacked dir creation for hardlink case
611+
dvc_repo.status()
612+
613+
dvc_repo.unprotect(repo_dir.DATA_SUB)
614+
615+
link_spy = spy(link_func)
616+
RemoteLOCAL.CACHE_TYPE_MAP[link] = link_spy
617+
dvc_repo.add(repo_dir.DATA_DIR)
618+
619+
assert link_spy.mock.call_count == 1
620+
621+
622+
@pytest.mark.parametrize("link", ["hardlink", "symlink", "copy"])
623+
def test_should_protect_on_repeated_add(link, dvc_repo, repo_dir):
624+
dvc_repo.cache.local.cache_types = [link]
625+
dvc_repo.cache.local.protected = True
626+
627+
dvc_repo.add(repo_dir.FOO)
628+
629+
dvc_repo.unprotect(repo_dir.FOO)
630+
631+
dvc_repo.add(repo_dir.FOO)
632+
633+
assert not os.access(repo_dir.FOO, os.W_OK)

0 commit comments

Comments
 (0)