diff --git a/dvc/repo/index.py b/dvc/repo/index.py index 18eb6d2a98..cd52979285 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -68,14 +68,50 @@ def __init__( if stages is not None: self.stages: List["Stage"] = stages self._collected_targets: Dict[int, List["StageInfo"]] = {} + self._metrics: Dict[str, List[str]] = {} + self._plots: Dict[str, Any] = {} + self._params: Dict[str, List[str]] = {} @cached_property def stages(self) -> List["Stage"]: # pylint: disable=method-hidden # note that ideally we should be keeping this in a set as it is unique, # hashable and has no concept of orderliness on its own. But we depend # on this to be somewhat ordered for status/metrics/plots, etc. + return self._collect() + + @cached_property + def _top_metrics(self): + self._collect() + return self._metrics + + @cached_property + def _top_plots(self): + self._collect() + return self._plots + + @cached_property + def _top_params(self): + self._collect() + return self._params + + def _collect(self): + if "stages" in self.__dict__: + return self.stages + onerror = self.repo.stage_collection_error_handler - return self.stage_collector.collect_repo(onerror=onerror) + + # pylint: disable=protected-access + ( + stages, + metrics, + plots, + params, + ) = self.stage_collector._collect_all_from_repo(onerror=onerror) + self.stages = stages + self._metrics = metrics + self._plots = plots + self._params = params + return stages def __repr__(self) -> str: from dvc.fs import LocalFileSystem diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index 69dc7c1ae0..82ca5cfa4c 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -9,7 +9,12 @@ from dvc.repo import locked from dvc.repo.collect import StrPaths, collect from dvc.scm import NoSCMError -from dvc.utils import error_handler, errored_revisions, onerror_collect +from dvc.utils import ( + as_posix, + error_handler, + errored_revisions, + onerror_collect, +) from dvc.utils.collections import ensure_list from dvc.utils.serialize import load_path @@ -28,6 +33,17 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths: return result +def _collect_top_level_metrics(repo): + top_metrics = repo.index._top_metrics # pylint: disable=protected-access + for dvcfile, metrics in top_metrics.items(): + wdir = repo.fs.path.relpath( + repo.fs.path.parent(dvcfile), repo.root_dir + ) + for file in metrics: + path = repo.fs.path.join(wdir, as_posix(file)) + yield repo.fs.path.normpath(path) + + def _collect_metrics(repo, targets, revision, recursive): metrics, fs_paths = collect( repo, @@ -94,6 +110,7 @@ def _read_metrics(repo, metrics, rev, onerror=None): def _gather_metrics(repo, targets, rev, recursive, onerror=None): metrics = _collect_metrics(repo, targets, rev, recursive) + metrics.extend(_collect_top_level_metrics(repo)) return _read_metrics(repo, metrics, rev, onerror=onerror) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 684eb4037d..f4a9b48c32 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -20,7 +20,12 @@ from dvc.scm import NoSCMError from dvc.stage import PipelineStage from dvc.ui import ui -from dvc.utils import error_handler, errored_revisions, onerror_collect +from dvc.utils import ( + as_posix, + error_handler, + errored_revisions, + onerror_collect, +) from dvc.utils.collections import ensure_list from dvc.utils.serialize import load_path @@ -35,6 +40,17 @@ def _is_params(dep: "Output"): return isinstance(dep, ParamsDependency) +def _collect_top_level_params(repo): + top_params = repo.index._top_params # pylint: disable=protected-access + for dvcfile, params in top_params.items(): + wdir = repo.fs.path.relpath( + repo.fs.path.parent(dvcfile), repo.root_dir + ) + for file in params: + path = repo.fs.path.join(wdir, as_posix(file)) + yield repo.fs.path.normpath(path) + + def _collect_configs( repo: "Repo", rev, targets=None, deps=False, stages=None ) -> Tuple[List["Output"], List[str]]: @@ -176,6 +192,7 @@ def _gather_params( param_outs, params_fs_paths = _collect_configs( repo, rev, targets=targets, deps=deps, stages=stages ) + params_fs_paths.extend(_collect_top_level_params(repo=repo)) params = _read_params( repo, params=param_outs, diff --git a/dvc/repo/plots/__init__.py b/dvc/repo/plots/__init__.py index 755980ddd2..74f12a4f4a 100644 --- a/dvc/repo/plots/__init__.py +++ b/dvc/repo/plots/__init__.py @@ -439,36 +439,33 @@ def _resolve_definitions( def _collect_pipeline_files(repo, targets: List[str], props, onerror=None): - from dvc.dvcfile import PipelineFile - result: Dict[str, Dict] = {} - dvcfiles = {stage.dvcfile for stage in repo.index.stages} - for dvcfile in dvcfiles: - if isinstance(dvcfile, PipelineFile): - dvcfile_path = _relpath(repo.dvcfs, dvcfile.path) - dvcfile_defs = dvcfile.load().get("plots", {}) - dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {} - if isinstance(dvcfile_defs, list): - for elem in dvcfile_defs: - if isinstance(elem, str): - dvcfile_defs_dict[elem] = None - else: - k, v = list(elem.items())[0] - dvcfile_defs_dict[k] = v - else: - dvcfile_defs_dict = dvcfile_defs - resolved = _resolve_definitions( - repo.dvcfs, - targets, - props, - dvcfile_path, - dvcfile_defs_dict, - onerror=onerror, - ) - dpath.util.merge( - result, - {dvcfile_path: resolved}, - ) + top_plots = repo.index._top_plots # pylint: disable=protected-access + for dvcfile, plots_def in top_plots.items(): + dvcfile_path = _relpath(repo.dvcfs, dvcfile) + dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {} + if isinstance(plots_def, list): + for elem in plots_def: + if isinstance(elem, str): + dvcfile_defs_dict[elem] = None + else: + k, v = list(elem.items())[0] + dvcfile_defs_dict[k] = v + else: + dvcfile_defs_dict = plots_def + + resolved = _resolve_definitions( + repo.dvcfs, + targets, + props, + dvcfile_path, + dvcfile_defs_dict, + onerror=onerror, + ) + dpath.util.merge( + result, + {dvcfile_path: resolved}, + ) return result diff --git a/dvc/repo/stage.py b/dvc/repo/stage.py index 84b4629fef..14a33786d5 100644 --- a/dvc/repo/stage.py +++ b/dvc/repo/stage.py @@ -502,3 +502,104 @@ def is_out_or_ignored(root, directory): def collect_repo(self, onerror: Callable[[str, Exception], None] = None): return list(self._collect_repo(onerror)) + + def _load_file(self, path): + from dvc.dvcfile import Dvcfile + from dvc.stage.loader import SingleStageLoader, StageLoader + + path = self._get_filepath(path) + dvcfile = Dvcfile(self.repo, path) + # `dvcfile.stages` is not cached + stages = dvcfile.stages # type: ignore + + if isinstance(stages, SingleStageLoader): + stages_ = [stages[None]] + else: + assert isinstance(stages, StageLoader) + keys = self._get_keys(stages) + stages_ = [stages[key] for key in keys] + + return ( + stages_, + stages.metrics_data, + stages.plots_data, + stages.params_data, + ) + + def _collect_all_from_repo( + self, onerror: Callable[[str, Exception], None] = None + ): + """Collects all of the stages present in the DVC repo. + + Args: + onerror (optional): callable that will be called with two args: + the filepath whose collection failed and the exc instance. + It can report the error to continue with the collection + (and, skip failed ones), or raise the exception to abort + the collection. + """ + from dvc.dvcfile import is_valid_filename + from dvc.fs import LocalFileSystem + + scm = self.repo.scm + sep = self.fs.sep + outs: Set[str] = set() + + is_local_fs = isinstance(self.fs, LocalFileSystem) + + def is_ignored(path): + # apply only for the local fs + return is_local_fs and scm.is_ignored(path) + + def is_dvcfile_and_not_ignored(root, file): + return is_valid_filename(file) and not is_ignored( + f"{root}{sep}{file}" + ) + + def is_out_or_ignored(root, directory): + dir_path = f"{root}{sep}{directory}" + # trailing slash needed to check if a directory is gitignored + return dir_path in outs or is_ignored(f"{dir_path}{sep}") + + walk_iter = self.repo.dvcignore.walk(self.fs, self.repo.root_dir) + if logger.isEnabledFor(logging.TRACE): # type: ignore[attr-defined] + walk_iter = log_walk(walk_iter) + + stages = [] + metrics = {} + plots = {} + params = {} + + for root, dirs, files in walk_iter: + dvcfile_filter = partial(is_dvcfile_and_not_ignored, root) + for file in filter(dvcfile_filter, files): + file_path = self.fs.path.join(root, file) + try: + ( + new_stages, + new_metrics, + new_plots, + new_params, + ) = self._load_file(file_path) + except DvcException as exc: + if onerror: + onerror(relpath(file_path), exc) + continue + raise + + stages.extend(new_stages) + if new_metrics: + metrics[file_path] = new_metrics + if new_plots: + plots[file_path] = new_plots + if new_params: + params[file_path] = new_params + + outs.update( + out.fspath + for stage in new_stages + for out in stage.outs + if out.protocol == "local" + ) + dirs[:] = [d for d in dirs if not is_out_or_ignored(root, d)] + return stages, metrics, plots, params diff --git a/dvc/schema.py b/dvc/schema.py index b48f61aa09..02c16c2760 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -123,6 +123,8 @@ def validator(data): PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]), STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, VARS_KWD: VARS_SCHEMA, + StageParams.PARAM_PARAMS: [str], + StageParams.PARAM_METRICS: [str], } COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA) diff --git a/dvc/stage/loader.py b/dvc/stage/loader.py index 391e43f38f..93092e3894 100644 --- a/dvc/stage/loader.py +++ b/dvc/stage/loader.py @@ -24,6 +24,9 @@ def __init__(self, dvcfile, data, lockfile_data=None): self.dvcfile = dvcfile self.data = data or {} self.stages_data = self.data.get("stages", {}) + self.metrics_data = self.data.get("metrics", []) + self.params_data = self.data.get("params", []) + self.plots_data = self.data.get("plots", {}) self.repo = self.dvcfile.repo lockfile_data = lockfile_data or {} @@ -171,6 +174,9 @@ def __init__(self, dvcfile, stage_data, stage_text=None): self.dvcfile = dvcfile self.stage_data = stage_data or {} self.stage_text = stage_text + self.metrics_data = [] + self.params_data = [] + self.plots_data = {} def __getitem__(self, item): if item: diff --git a/tests/func/metrics/test_diff.py b/tests/func/metrics/test_diff.py index c88085d31a..23adde515c 100644 --- a/tests/func/metrics/test_diff.py +++ b/tests/func/metrics/test_diff.py @@ -1,6 +1,10 @@ import json +from os.path import join + +import pytest from dvc.cli import main +from dvc.utils import relpath def test_metrics_diff_simple(tmp_dir, scm, dvc, run_copy_metrics): @@ -206,3 +210,28 @@ def _gen(val): assert result == { "some_file.yaml": {"foo": {"old": 1, "new": 3, "diff": 2}} } + + +@pytest.mark.parametrize( + "dvcfile, metrics_file", + [ + ("dvc.yaml", "my_metrics.yaml"), + ("dir/dvc.yaml", "my_metrics.yaml"), + ("dir/dvc.yaml", join("..", "my_metrics.yaml")), + ], +) +def test_diff_top_level_metrics(tmp_dir, dvc, scm, dvcfile, metrics_file): + directory = (tmp_dir / dvcfile).parent + directory.mkdir(exist_ok=True) + (tmp_dir / dvcfile).dump({"metrics": [metrics_file]}) + + metrics_file = directory / metrics_file + metrics_file.dump({"foo": 3}) + scm.add_commit([metrics_file, tmp_dir / dvcfile], message="add metrics") + + metrics_file.dump({"foo": 5}) + assert dvc.metrics.diff() == { + relpath(directory / metrics_file): { + "foo": {"diff": 2, "new": 5, "old": 3} + } + } diff --git a/tests/func/params/test_diff.py b/tests/func/params/test_diff.py index f409e40124..ba181ba38b 100644 --- a/tests/func/params/test_diff.py +++ b/tests/func/params/test_diff.py @@ -1,3 +1,5 @@ +from os.path import join + import pytest from dvc.utils import relpath @@ -241,3 +243,28 @@ def test_diff_without_targets_specified(tmp_dir, dvc, scm, file): "y": {"new": "100", "old": None}, } } + + +@pytest.mark.parametrize( + "dvcfile, params_file", + [ + ("dvc.yaml", "my_params.yaml"), + ("dir/dvc.yaml", "my_params.yaml"), + ("dir/dvc.yaml", join("..", "my_params.yaml")), + ], +) +def test_diff_top_level_params(tmp_dir, dvc, scm, dvcfile, params_file): + directory = (tmp_dir / dvcfile).parent + directory.mkdir(exist_ok=True) + (tmp_dir / dvcfile).dump({"params": [params_file]}) + + params_file = directory / params_file + params_file.dump({"foo": 3}) + scm.add_commit([params_file, tmp_dir / dvcfile], message="add params") + + params_file.dump({"foo": 5}) + assert dvc.params.diff() == { + relpath(directory / params_file): { + "foo": {"diff": 2, "new": 5, "old": 3} + } + } diff --git a/tests/unit/command/test_status.py b/tests/unit/command/test_status.py index 301982f0f1..fd69a84f59 100644 --- a/tests/unit/command/test_status.py +++ b/tests/unit/command/test_status.py @@ -81,7 +81,7 @@ def test_status_empty(dvc, mocker, capsys): cmd = cli_args.func(cli_args) - spy = mocker.spy(cmd.repo.stage, "_collect_repo") + spy = mocker.spy(cmd.repo.stage, "_collect_all_from_repo") assert cmd.run() == 0 @@ -108,7 +108,10 @@ def test_status_up_to_date(dvc, mocker, capsys, cloud_opts, expected_message): mocker.patch.dict(cmd.repo.config, {"core": {"remote": "default"}}) mocker.patch.object(cmd.repo, "status", autospec=True, return_value={}) mocker.patch.object( - cmd.repo.stage, "_collect_repo", return_value=[object()], autospec=True + cmd.repo.stage, + "_collect_all_from_repo", + return_value=[[object()], {}, {}, {}], + autospec=True, ) assert cmd.run() == 0