Skip to content

Commit

Permalink
perf: remove fs exists check in plots, parallel data collect
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Jan 11, 2023
1 parent a060049 commit 1f1ed6e
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 54 deletions.
11 changes: 1 addition & 10 deletions dvc/repo/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def _collect_paths(
repo: "Repo",
targets: Iterable[str],
recursive: bool = False,
rev: str = None,
) -> StrPaths:
from dvc.fs.dvc import DVCFileSystem

Expand All @@ -39,13 +38,6 @@ def _collect_paths(
for fs_path in fs_paths:
if recursive and fs.isdir(fs_path):
target_paths.extend(fs.find(fs_path))

rel = fs.path.relpath(fs_path)
if not fs.exists(fs_path):
if rev == "workspace" or rev == "":
logger.warning("'%s' was not found in current workspace.", rel)
else:
logger.warning("'%s' was not found at: '%s'.", rel, rev)
target_paths.append(fs_path)

return target_paths
Expand Down Expand Up @@ -73,7 +65,6 @@ def collect(
deps: bool = False,
targets: Iterable[str] = None,
output_filter: FilterFn = None,
rev: str = None,
recursive: bool = False,
duplicates: bool = False,
) -> Tuple[Outputs, StrPaths]:
Expand All @@ -85,6 +76,6 @@ def collect(
fs_paths: StrPaths = []
return outs, fs_paths

target_paths = _collect_paths(repo, targets, recursive=recursive, rev=rev)
target_paths = _collect_paths(repo, targets, recursive=recursive)

return _filter_outs(outs, target_paths, duplicates=duplicates)
2 changes: 1 addition & 1 deletion dvc/repo/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def collect_experiment_commit(
result["timestamp"] = datetime.fromtimestamp(commit.commit_time)

params = _gather_params(
repo, rev=rev, targets=None, deps=param_deps, onerror=onerror
repo, targets=None, deps=param_deps, onerror=onerror
)
if params:
result["params"] = params
Expand Down
10 changes: 3 additions & 7 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ def _collect_top_level_metrics(repo):
yield repo.fs.path.normpath(path)


def _collect_metrics(repo, targets, revision, recursive):
def _collect_metrics(repo, targets, recursive):
metrics, fs_paths = collect(
repo,
targets=targets,
output_filter=_is_metric,
recursive=recursive,
rev=revision,
repo, targets=targets, output_filter=_is_metric, recursive=recursive
)
return _to_fs_paths(metrics) + list(fs_paths)

Expand Down Expand Up @@ -109,7 +105,7 @@ def _read_metrics(repo, metrics, rev, onerror=None):


def _gather_metrics(repo, targets, rev, recursive, onerror=None):
metrics = _collect_metrics(repo, targets, rev, recursive)
metrics = _collect_metrics(repo, targets, recursive)
metrics.extend(_collect_top_level_metrics(repo))
return _read_metrics(repo, metrics, rev, onerror=onerror)

Expand Down
10 changes: 3 additions & 7 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,14 @@ def _collect_top_level_params(repo):


def _collect_configs(
repo: "Repo", rev, targets=None, deps=False, stages=None
repo: "Repo", targets=None, deps=False, stages=None
) -> Tuple[List["Output"], List[str]]:

params, fs_paths = collect(
repo,
targets=targets or [],
deps=True,
output_filter=_is_params,
rev=rev,
duplicates=deps or stages is not None,
)
all_fs_paths = fs_paths + [p.fs_path for p in params]
Expand Down Expand Up @@ -156,7 +155,6 @@ def show(
for branch in repo.brancher(revs=revs):
params = error_handler(_gather_params)(
repo=repo,
rev=branch,
targets=targets,
deps=deps,
onerror=onerror,
Expand Down Expand Up @@ -188,11 +186,9 @@ def show(
return res


def _gather_params(
repo, rev, targets=None, deps=False, onerror=None, stages=None
):
def _gather_params(repo, targets=None, deps=False, onerror=None, stages=None):
param_outs, params_fs_paths = _collect_configs(
repo, rev, targets=targets, deps=deps, stages=stages
repo, targets=targets, deps=deps, stages=stages
)
params_fs_paths.extend(_collect_top_level_params(repo=repo))
params = _read_params(
Expand Down
33 changes: 23 additions & 10 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from multiprocessing import cpu_count
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -24,6 +25,7 @@
from dvc.exceptions import DvcException
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils.serialize import LOADERS
from dvc.utils.threadpool import ThreadPoolExecutor

if TYPE_CHECKING:
from dvc.output import Output
Expand Down Expand Up @@ -136,7 +138,6 @@ def collect(
data_targets = _get_data_targets(definitions)

res[rev]["sources"] = self._collect_data_sources(
revision=rev,
targets=data_targets,
recursive=recursive,
props=props,
Expand All @@ -148,7 +149,6 @@ def collect(
def _collect_data_sources(
self,
targets: Optional[List[str]] = None,
revision: Optional[str] = None,
recursive: bool = False,
props: Optional[Dict] = None,
onerror: Optional[Callable] = None,
Expand All @@ -159,7 +159,7 @@ def _collect_data_sources(

props = props or {}

plots = _collect_plots(self.repo, targets, revision, recursive)
plots = _collect_plots(self.repo, targets, recursive)
res: Dict[str, Any] = {}
for fs_path, rev_props in plots.items():
joined_props = {**rev_props, **props}
Expand Down Expand Up @@ -270,19 +270,33 @@ def _is_plot(out: "Output") -> bool:


def _resolve_data_sources(plots_data: Dict):
for value in plots_data.values():
values = list(plots_data.values())
to_resolve = []
while values:
value = values.pop()
if isinstance(value, dict):
if "data_source" in value:
data_source = value.pop("data_source")
assert callable(data_source)
value.update(data_source())
_resolve_data_sources(value)
to_resolve.append(value)
values.extend(value.values())

def resolve(value):
data_source = value.pop("data_source")
assert callable(data_source)
value.update(data_source())

executor = ThreadPoolExecutor(
max_workers=4 * cpu_count(),
thread_name_prefix="resolve_data",
cancel_on_error=True,
)
with executor:
# imap_unordered is lazy, wrapping to trigger it
list(executor.imap_unordered(resolve, to_resolve))


def _collect_plots(
repo: "Repo",
targets: List[str] = None,
rev: str = None,
recursive: bool = False,
) -> Dict[str, Dict]:
from dvc.repo.collect import collect
Expand All @@ -291,7 +305,6 @@ def _collect_plots(
repo,
output_filter=_is_plot,
targets=targets,
rev=rev,
recursive=recursive,
)

Expand Down
1 change: 1 addition & 0 deletions dvc/utils/threadpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def create_taskset(n: int) -> Set[futures.Future]:

it = zip(*iterables)
tasks = create_taskset(self.max_workers * 5)

while tasks:
done, tasks = futures.wait(
tasks, return_when=futures.FIRST_COMPLETED
Expand Down
10 changes: 2 additions & 8 deletions tests/func/metrics/test_show.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os

import pytest
Expand Down Expand Up @@ -258,13 +257,8 @@ def test_show_malformed_metric(tmp_dir, scm, dvc, caplog):
)


def test_metrics_show_no_target(tmp_dir, dvc, caplog):
with caplog.at_level(logging.WARNING):
assert dvc.metrics.show(targets=["metrics.json"]) == {"": {}}

assert (
"'metrics.json' was not found in current workspace." in caplog.messages
)
def test_metrics_show_no_target(tmp_dir, dvc, capsys):
assert dvc.metrics.show(targets=["metrics.json"]) == {"": {}}


def test_show_no_metrics_files(tmp_dir, dvc, caplog):
Expand Down
8 changes: 6 additions & 2 deletions tests/func/plots/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,18 @@ def test_show_from_subdir(tmp_dir, dvc, capsys):
assert (subdir / "dvc_plots" / "index.html").is_file()


def test_plots_show_non_existing(tmp_dir, dvc, caplog):
def test_plots_show_non_existing(tmp_dir, dvc, capsys):
result = dvc.plots.show(targets=["plot.json"])
assert isinstance(
get_plot(result, "workspace", file="plot.json", endkey="error"),
FileNotFoundError,
)

assert "'plot.json' was not found in current workspace." in caplog.text
cap = capsys.readouterr()
assert (
"DVC failed to load some plots for following revisions: 'workspace'"
in cap.err
)


@pytest.mark.parametrize("clear_before_run", [True, False])
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/test_collect.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
import logging

from dvc.repo.collect import collect


def test_no_file_on_target_rev(tmp_dir, scm, dvc, caplog):
with caplog.at_level(logging.WARNING, "dvc"):
collect(dvc, targets=["file.yaml"], rev="current_branch")

assert "'file.yaml' was not found at: 'current_branch'." in caplog.text


def test_collect_duplicates(tmp_dir, scm, dvc):
tmp_dir.gen("params.yaml", "foo: 1\nbar: 2")
tmp_dir.gen("foobar", "")
Expand Down

0 comments on commit 1f1ed6e

Please sign in to comment.