objects.db: get rid of callbacks from internal _list_paths API #7486

Merged
merged 1 commit on Mar 25, 2022
9 changes: 2 additions & 7 deletions dvc/data/db/local.py
@@ -78,7 +78,7 @@ def hashes_exist(

         return ret

-    def _list_paths(self, prefix=None, progress_callback=None):
+    def _list_paths(self, prefix=None):
         assert self.fs_path is not None
         if prefix:
             fs_path = self.fs.path.join(self.fs_path, prefix[:2])
@@ -89,12 +89,7 @@ def _list_paths(self, prefix=None, progress_callback=None):

         # NOTE: use utils.fs walk_files since fs.walk_files will not follow
         # symlinks
-        if progress_callback:
-            for path in walk_files(fs_path):
-                progress_callback()
-                yield path
-        else:
-            yield from walk_files(fs_path)
+        yield from walk_files(fs_path)

     def _remove_unpacked_dir(self, hash_):
         hash_fs_path = self.hash_to_path(hash_)
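The local.py hunk is the simplest instance of the PR's theme: `_list_paths` becomes a plain generator again, and progress accounting moves to the call site. A minimal sketch of the before/after pattern (illustrative names, not DVC's actual API):

```python
from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")

# Before: every producer must thread a callback through its signature.
def list_items(
    items: Iterable[T], progress_callback: Optional[Callable[[], None]] = None
) -> Iterator[T]:
    for item in items:
        if progress_callback is not None:
            progress_callback()
        yield item

# After: the producer stays oblivious; one reusable wrapper reports
# progress wherever the caller actually consumes the iterator.
def with_progress(items: Iterable[T], on_item: Callable[[], None]) -> Iterator[T]:
    for item in items:
        on_item()  # e.g. pbar.update()
        yield item
```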
57 changes: 23 additions & 34 deletions dvc/objects/db.py
@@ -10,6 +10,8 @@
 from dvc.progress import Tqdm

 if TYPE_CHECKING:
+    from typing import Tuple
+
     from dvc.fs.base import FileSystem
     from dvc.hash_info import HashInfo
     from dvc.types import AnyPath
@@ -179,24 +181,14 @@ def check(
         # next time
         self.protect(obj.fs_path)

-    def _list_paths(self, prefix=None, progress_callback=None):
+    def _list_paths(self, prefix: str = None):
+        prefix = prefix or ""
+        parts: "Tuple[str, ...]" = (self.fs_path,)
         if prefix:
-            if len(prefix) > 2:
-                fs_path = self.fs.path.join(
-                    self.fs_path, prefix[:2], prefix[2:]
-                )
-            else:
-                fs_path = self.fs.path.join(self.fs_path, prefix[:2])
-            prefix = True
-        else:
-            fs_path = self.fs_path
-            prefix = False
-        if progress_callback:
-            for file_info in self.fs.find(fs_path, prefix=prefix):
-                progress_callback()
-                yield file_info
-        else:
-            yield from self.fs.find(fs_path, prefix=prefix)
+            parts = *parts, prefix[:2]
+        if len(prefix) > 2:
+            parts = *parts, prefix[2:]
+        yield from self.fs.find(self.fs.path.join(*parts), prefix=bool(prefix))

     def _path_to_hash(self, path):
         parts = self.fs.path.parts(path)[-2:]
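The rewritten `_list_paths` assembles the search root from at most three parts: the cache root, the two-character shard directory, and the rest of the hash. The `prefix=bool(prefix)` forwarded to `fs.find` carries the same flag the old code computed as `prefix = True`/`False`, marking the final component as a partial name. A sketch of the parts logic, with a plain `posixpath.join` standing in for `self.fs.path.join`:

```python
import posixpath
from typing import Optional, Tuple

def find_args(fs_path: str, prefix: Optional[str] = None) -> Tuple[str, bool]:
    # Mirrors the new parts-building logic in _list_paths.
    prefix = prefix or ""
    parts: Tuple[str, ...] = (fs_path,)
    if prefix:
        parts = *parts, prefix[:2]   # shard directory, e.g. "12"
    if len(prefix) > 2:
        parts = *parts, prefix[2:]   # remainder of the hash
    return posixpath.join(*parts), bool(prefix)

assert find_args("/cache") == ("/cache", False)
assert find_args("/cache", "12") == ("/cache/12", True)
assert find_args("/cache", "1234abcd") == ("/cache/12/34abcd", True)
```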
@@ -206,23 +198,23 @@ def _path_to_hash(self, path):

         return "".join(parts)

-    def _list_hashes(self, prefix=None, progress_callback=None):
+    def _list_hashes(self, prefix=None):
         """Iterate over hashes in this fs.

         If `prefix` is specified, only hashes which begin with `prefix`
         will be returned.
         """
-        for path in self._list_paths(prefix, progress_callback):
+        for path in self._list_paths(prefix):
             try:
                 yield self._path_to_hash(path)
             except ValueError:
                 logger.debug(
                     "'%s' doesn't look like a cache file, skipping", path
                 )

-    def _hashes_with_limit(self, limit, prefix=None, progress_callback=None):
+    def _hashes_with_limit(self, limit, prefix=None):
         count = 0
-        for hash_ in self._list_hashes(prefix, progress_callback):
+        for hash_ in self._list_hashes(prefix):
             yield hash_
             count += 1
             if count > limit:
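`_hashes_with_limit` keeps its fractional limit semantics: the caller below passes `max_hashes / total_prefixes`, and iteration stops once `count > limit`. A tiny worked example with hypothetical numbers:

```python
def hashes_with_limit(hashes, limit):
    # Same shape as _hashes_with_limit: stop once count exceeds limit.
    count = 0
    for hash_ in hashes:
        yield hash_
        count += 1
        if count > limit:
            return

# With a hypothetical max_hashes=2500 spread over 256 two-character
# prefixes, limit is about 9.77, so at most 10 hashes per prefix.
assert len(list(hashes_with_limit(iter(range(100)), 2500 / 256))) == 10
```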
@@ -258,17 +250,19 @@ def _estimate_remote_size(self, hashes=None, name=None):
             unit="file",
         ) as pbar:

-            def update(n=1):
-                pbar.update(n * total_prefixes)
+            def iter_with_pbar(hashes):
+                for hash_ in hashes:
+                    pbar.update(total_prefixes)
+                    yield hash_

             if max_hashes:
                 hashes = self._hashes_with_limit(
-                    max_hashes / total_prefixes, prefix, update
+                    max_hashes / total_prefixes, prefix
                 )
             else:
-                hashes = self._list_hashes(prefix, update)
+                hashes = self._list_hashes(prefix)

-            remote_hashes = set(hashes)
+            remote_hashes = set(iter_with_pbar(hashes))
             if remote_hashes:
                 remote_size = total_prefixes * len(remote_hashes)
             else:
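The `update(n)` callback becomes `iter_with_pbar`, so the bar advances as `set()` drains the iterator; each sampled hash still stands for `total_prefixes` files in the estimate. A self-contained sketch with a stand-in progress bar (not DVC's `Tqdm`):

```python
def iter_with_pbar(hashes, pbar, total_prefixes):
    # Each hash found under one sampled prefix represents roughly
    # total_prefixes files in the full remote, so tick the bar by that.
    for hash_ in hashes:
        pbar.update(total_prefixes)
        yield hash_

class FakeBar:  # stand-in for dvc.progress.Tqdm
    def __init__(self) -> None:
        self.n = 0
    def update(self, n: int) -> None:
        self.n += n

pbar = FakeBar()
remote_hashes = set(iter_with_pbar(["abc", "def"], pbar, 256))
assert pbar.n == 512
remote_size = 256 * len(remote_hashes)  # the estimate: 512 here
```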
@@ -319,18 +313,13 @@ def _list_hashes_traverse(
             initial=initial,
             unit="file",
         ) as pbar:
-
-            def list_with_update(prefix):
-                return list(
-                    self._list_hashes(
-                        prefix=prefix, progress_callback=pbar.update
-                    )
-                )
+            from funcy import collecting

             with ThreadPoolExecutor(
                 max_workers=jobs or self.fs.jobs
             ) as executor:
-                in_remote = executor.map(list_with_update, traverse_prefixes)
+                list_hashes = collecting(pbar.wrap_fn(self._list_hashes))
+                in_remote = executor.map(list_hashes, traverse_prefixes)
                 yield from itertools.chain.from_iterable(in_remote)

     def all(self, jobs=None, name=None):
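Two helpers replace `list_with_update`: `pbar.wrap_fn` (DVC's `Tqdm` helper, defined in dvc/progress.py) takes over the progress updates the callback used to deliver, and funcy's `collecting` turns the generator function into a list-returning one, so each `executor.map` worker drains its listing inside the pool instead of handing back a lazy generator. A reduced sketch of the composition; `wrap_gen_fn` is an illustrative stand-in for `wrap_fn`, not its actual implementation:

```python
import itertools
from concurrent.futures import ThreadPoolExecutor
from functools import wraps

from funcy import collecting  # real funcy decorator: returns list(fn(...))

def wrap_gen_fn(gen_fn, update):
    # Illustrative stand-in for pbar.wrap_fn: tick `update` once per
    # item produced by the wrapped generator function.
    @wraps(gen_fn)
    def wrapped(*args, **kwargs):
        for item in gen_fn(*args, **kwargs):
            update()
            yield item
    return wrapped

ticks = []

def fake_list_hashes(prefix):  # hypothetical per-prefix listing
    yield from (prefix + "0", prefix + "1")

list_hashes = collecting(wrap_gen_fn(fake_list_hashes, lambda: ticks.append(1)))

with ThreadPoolExecutor(max_workers=4) as executor:
    # collecting() forces each worker to consume its generator, so the
    # listing happens in the pool rather than lazily in the consumer.
    in_remote = executor.map(list_hashes, ["aa", "ab"])
    hashes = list(itertools.chain.from_iterable(in_remote))

assert sorted(hashes) == ["aa0", "aa1", "ab0", "ab1"] and len(ticks) == 4
```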
12 changes: 3 additions & 9 deletions tests/unit/remote/test_base.py
@@ -75,21 +75,15 @@ def test_list_hashes_traverse(_path_to_hash, list_hashes, dvc):
     size = 256 / odb.fs._JOBS * odb.fs.LIST_OBJECT_PAGE_SIZE
     list(odb._list_hashes_traverse(size, {0}))
     for i in range(1, 16):
-        list_hashes.assert_any_call(
-            prefix=f"{i:03x}", progress_callback=CallableOrNone
-        )
+        list_hashes.assert_any_call(f"{i:03x}")
     for i in range(1, 256):
-        list_hashes.assert_any_call(
-            prefix=f"{i:02x}", progress_callback=CallableOrNone
-        )
+        list_hashes.assert_any_call(f"{i:02x}")

     # default traverse (small remote)
     size -= 1
     list_hashes.reset_mock()
     list(odb._list_hashes_traverse(size - 1, {0}))
-    list_hashes.assert_called_with(
-        prefix=None, progress_callback=CallableOrNone
-    )
+    list_hashes.assert_called_with(None)


 def test_list_hashes(dvc):
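A side note on why the assertions had to change shape and not just drop `progress_callback`: `unittest.mock` matches calls exactly, so once `_list_hashes` is invoked positionally (via `executor.map`), a keyword-style `assert_any_call(prefix=...)` would no longer match. For example:

```python
from unittest import mock

m = mock.Mock()
m("aa")                    # positional call, as executor.map produces
m.assert_any_call("aa")    # matches

try:
    m.assert_any_call(prefix="aa")  # keyword form does not match
except AssertionError:
    print("positional and keyword calls are distinct to mock")
```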