From 2a3607c4d5b7d1222c4f1b5282b4d859b9145256 Mon Sep 17 00:00:00 2001
From: Saugat Pachhai (सौगात)
Date: Tue, 22 Mar 2022 16:25:40 +0545
Subject: [PATCH] objects.db: get rid of callbacks from internal _list_paths
 API

This removes callbacks from the internal APIs `_list_paths`,
`_list_hashes`, and `_hashes_with_limit`. These are all iterators, so the
callbacks can be replaced by wrapping a progress bar around the iterators
on the caller side (sketched in the notes after the patch).

We still have progress bars in `hashes_exist`, `_estimate_remote_size`,
`_list_hashes_traverse`, and `list_hashes_exists`; these should gradually
be lifted up and replaced as well.
---
 dvc/data/db/local.py           |  9 ++----
 dvc/objects/db.py              | 57 ++++++++++++++--------------------
 tests/unit/remote/test_base.py | 12 ++-----
 3 files changed, 28 insertions(+), 50 deletions(-)

diff --git a/dvc/data/db/local.py b/dvc/data/db/local.py
index f7cb63954d..cb4004e289 100644
--- a/dvc/data/db/local.py
+++ b/dvc/data/db/local.py
@@ -78,7 +78,7 @@ def hashes_exist(
 
         return ret
 
-    def _list_paths(self, prefix=None, progress_callback=None):
+    def _list_paths(self, prefix=None):
         assert self.fs_path is not None
         if prefix:
             fs_path = self.fs.path.join(self.fs_path, prefix[:2])
@@ -89,12 +89,7 @@ def _list_paths(self, prefix=None, progress_callback=None):
 
         # NOTE: use utils.fs walk_files since fs.walk_files will not follow
         # symlinks
-        if progress_callback:
-            for path in walk_files(fs_path):
-                progress_callback()
-                yield path
-        else:
-            yield from walk_files(fs_path)
+        yield from walk_files(fs_path)
 
     def _remove_unpacked_dir(self, hash_):
         hash_fs_path = self.hash_to_path(hash_)
diff --git a/dvc/objects/db.py b/dvc/objects/db.py
index a30c2c6fbb..4336b9bb45 100644
--- a/dvc/objects/db.py
+++ b/dvc/objects/db.py
@@ -10,6 +10,8 @@
 from dvc.progress import Tqdm
 
 if TYPE_CHECKING:
+    from typing import Tuple
+
     from dvc.fs.base import FileSystem
     from dvc.hash_info import HashInfo
     from dvc.types import AnyPath
@@ -179,24 +181,14 @@ def check(
         # next time
         self.protect(obj.fs_path)
 
-    def _list_paths(self, prefix=None, progress_callback=None):
+    def _list_paths(self, prefix: str = None):
+        prefix = prefix or ""
+        parts: "Tuple[str, ...]" = (self.fs_path,)
         if prefix:
-            if len(prefix) > 2:
-                fs_path = self.fs.path.join(
-                    self.fs_path, prefix[:2], prefix[2:]
-                )
-            else:
-                fs_path = self.fs.path.join(self.fs_path, prefix[:2])
-            prefix = True
-        else:
-            fs_path = self.fs_path
-            prefix = False
-        if progress_callback:
-            for file_info in self.fs.find(fs_path, prefix=prefix):
-                progress_callback()
-                yield file_info
-        else:
-            yield from self.fs.find(fs_path, prefix=prefix)
+            parts = *parts, prefix[:2]
+            if len(prefix) > 2:
+                parts = *parts, prefix[2:]
+        yield from self.fs.find(self.fs.path.join(*parts), prefix=bool(prefix))
 
     def _path_to_hash(self, path):
         parts = self.fs.path.parts(path)[-2:]
@@ -206,13 +198,13 @@ def _path_to_hash(self, path):
 
         return "".join(parts)
 
-    def _list_hashes(self, prefix=None, progress_callback=None):
+    def _list_hashes(self, prefix=None):
         """Iterate over hashes in this fs.
 
         If `prefix` is specified, only hashes which begin with `prefix`
         will be returned.
""" - for path in self._list_paths(prefix, progress_callback): + for path in self._list_paths(prefix): try: yield self._path_to_hash(path) except ValueError: @@ -220,9 +212,9 @@ def _list_hashes(self, prefix=None, progress_callback=None): "'%s' doesn't look like a cache file, skipping", path ) - def _hashes_with_limit(self, limit, prefix=None, progress_callback=None): + def _hashes_with_limit(self, limit, prefix=None): count = 0 - for hash_ in self._list_hashes(prefix, progress_callback): + for hash_ in self._list_hashes(prefix): yield hash_ count += 1 if count > limit: @@ -258,17 +250,19 @@ def _estimate_remote_size(self, hashes=None, name=None): unit="file", ) as pbar: - def update(n=1): - pbar.update(n * total_prefixes) + def iter_with_pbar(hashes): + for hash_ in hashes: + pbar.update(total_prefixes) + yield hash_ if max_hashes: hashes = self._hashes_with_limit( - max_hashes / total_prefixes, prefix, update + max_hashes / total_prefixes, prefix ) else: - hashes = self._list_hashes(prefix, update) + hashes = self._list_hashes(prefix) - remote_hashes = set(hashes) + remote_hashes = set(iter_with_pbar(hashes)) if remote_hashes: remote_size = total_prefixes * len(remote_hashes) else: @@ -319,18 +313,13 @@ def _list_hashes_traverse( initial=initial, unit="file", ) as pbar: - - def list_with_update(prefix): - return list( - self._list_hashes( - prefix=prefix, progress_callback=pbar.update - ) - ) + from funcy import collecting with ThreadPoolExecutor( max_workers=jobs or self.fs.jobs ) as executor: - in_remote = executor.map(list_with_update, traverse_prefixes) + list_hashes = collecting(pbar.wrap_fn(self._list_hashes)) + in_remote = executor.map(list_hashes, traverse_prefixes) yield from itertools.chain.from_iterable(in_remote) def all(self, jobs=None, name=None): diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 0941922c8c..968e7db701 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -75,21 +75,15 @@ def test_list_hashes_traverse(_path_to_hash, list_hashes, dvc): size = 256 / odb.fs._JOBS * odb.fs.LIST_OBJECT_PAGE_SIZE list(odb._list_hashes_traverse(size, {0})) for i in range(1, 16): - list_hashes.assert_any_call( - prefix=f"{i:03x}", progress_callback=CallableOrNone - ) + list_hashes.assert_any_call(f"{i:03x}") for i in range(1, 256): - list_hashes.assert_any_call( - prefix=f"{i:02x}", progress_callback=CallableOrNone - ) + list_hashes.assert_any_call(f"{i:02x}") # default traverse (small remote) size -= 1 list_hashes.reset_mock() list(odb._list_hashes_traverse(size - 1, {0})) - list_hashes.assert_called_with( - prefix=None, progress_callback=CallableOrNone - ) + list_hashes.assert_called_with(None) def test_list_hashes(dvc):