Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: top level params and metrics #8529

Merged
merged 4 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,50 @@ def __init__(
if stages is not None:
self.stages: List["Stage"] = stages
self._collected_targets: Dict[int, List["StageInfo"]] = {}
self._metrics: Dict[str, List[str]] = {}
self._plots: Dict[str, Any] = {}
self._params: Dict[str, List[str]] = {}

@cached_property
def stages(self) -> List["Stage"]: # pylint: disable=method-hidden
# note that ideally we should be keeping this in a set as it is unique,
# hashable and has no concept of orderliness on its own. But we depend
# on this to be somewhat ordered for status/metrics/plots, etc.
return self._collect()

@cached_property
def _top_metrics(self):
self._collect()
return self._metrics

@cached_property
def _top_plots(self):
self._collect()
return self._plots

@cached_property
def _top_params(self):
self._collect()
return self._params

def _collect(self):
if "stages" in self.__dict__:
return self.stages
Comment on lines +83 to +99
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These all might be temporary, we probably should unify top-level metrics and stage-level metrics. Hence, keeping them as private.


onerror = self.repo.stage_collection_error_handler
return self.stage_collector.collect_repo(onerror=onerror)

# pylint: disable=protected-access
(
stages,
metrics,
plots,
params,
) = self.stage_collector._collect_all_from_repo(onerror=onerror)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Collection needs to be brought inside Index from Repo.stage. This seemed to require more change than I thought, as we have lots of dependencies on stage.collect(), which probably also needs to be changed/refactored.

self.stages = stages
self._metrics = metrics
self._plots = plots
self._params = params
return stages

def __repr__(self) -> str:
from dvc.fs import LocalFileSystem
Expand Down
19 changes: 18 additions & 1 deletion dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from dvc.repo import locked
from dvc.repo.collect import StrPaths, collect
from dvc.scm import NoSCMError
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils import (
as_posix,
error_handler,
errored_revisions,
onerror_collect,
)
from dvc.utils.collections import ensure_list
from dvc.utils.serialize import load_path

Expand All @@ -28,6 +33,17 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths:
return result


def _collect_top_level_metrics(repo):
top_metrics = repo.index._top_metrics # pylint: disable=protected-access
for dvcfile, metrics in top_metrics.items():
wdir = repo.fs.path.relpath(
repo.fs.path.parent(dvcfile), repo.root_dir
)
for file in metrics:
path = repo.fs.path.join(wdir, as_posix(file))
yield repo.fs.path.normpath(path)


def _collect_metrics(repo, targets, revision, recursive):
metrics, fs_paths = collect(
repo,
Expand Down Expand Up @@ -94,6 +110,7 @@ def _read_metrics(repo, metrics, rev, onerror=None):

def _gather_metrics(repo, targets, rev, recursive, onerror=None):
metrics = _collect_metrics(repo, targets, rev, recursive)
metrics.extend(_collect_top_level_metrics(repo))
return _read_metrics(repo, metrics, rev, onerror=onerror)


Expand Down
19 changes: 18 additions & 1 deletion dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from dvc.scm import NoSCMError
from dvc.stage import PipelineStage
from dvc.ui import ui
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils import (
as_posix,
error_handler,
errored_revisions,
onerror_collect,
)
from dvc.utils.collections import ensure_list
from dvc.utils.serialize import load_path

Expand All @@ -35,6 +40,17 @@ def _is_params(dep: "Output"):
return isinstance(dep, ParamsDependency)


def _collect_top_level_params(repo):
top_params = repo.index._top_params # pylint: disable=protected-access
for dvcfile, params in top_params.items():
wdir = repo.fs.path.relpath(
repo.fs.path.parent(dvcfile), repo.root_dir
)
for file in params:
path = repo.fs.path.join(wdir, as_posix(file))
yield repo.fs.path.normpath(path)


def _collect_configs(
repo: "Repo", rev, targets=None, deps=False, stages=None
) -> Tuple[List["Output"], List[str]]:
Expand Down Expand Up @@ -176,6 +192,7 @@ def _gather_params(
param_outs, params_fs_paths = _collect_configs(
repo, rev, targets=targets, deps=deps, stages=stages
)
params_fs_paths.extend(_collect_top_level_params(repo=repo))
params = _read_params(
repo,
params=param_outs,
Expand Down
55 changes: 26 additions & 29 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,36 +439,33 @@ def _resolve_definitions(


def _collect_pipeline_files(repo, targets: List[str], props, onerror=None):
from dvc.dvcfile import PipelineFile

result: Dict[str, Dict] = {}
dvcfiles = {stage.dvcfile for stage in repo.index.stages}
for dvcfile in dvcfiles:
if isinstance(dvcfile, PipelineFile):
dvcfile_path = _relpath(repo.dvcfs, dvcfile.path)
dvcfile_defs = dvcfile.load().get("plots", {})
dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {}
if isinstance(dvcfile_defs, list):
for elem in dvcfile_defs:
if isinstance(elem, str):
dvcfile_defs_dict[elem] = None
else:
k, v = list(elem.items())[0]
dvcfile_defs_dict[k] = v
else:
dvcfile_defs_dict = dvcfile_defs
resolved = _resolve_definitions(
repo.dvcfs,
targets,
props,
dvcfile_path,
dvcfile_defs_dict,
onerror=onerror,
)
dpath.util.merge(
result,
{dvcfile_path: resolved},
)
top_plots = repo.index._top_plots # pylint: disable=protected-access
for dvcfile, plots_def in top_plots.items():
dvcfile_path = _relpath(repo.dvcfs, dvcfile)
dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {}
if isinstance(plots_def, list):
for elem in plots_def:
if isinstance(elem, str):
dvcfile_defs_dict[elem] = None
else:
k, v = list(elem.items())[0]
dvcfile_defs_dict[k] = v
else:
dvcfile_defs_dict = plots_def

resolved = _resolve_definitions(
repo.dvcfs,
targets,
props,
dvcfile_path,
dvcfile_defs_dict,
onerror=onerror,
)
dpath.util.merge(
result,
{dvcfile_path: resolved},
)
return result


Expand Down
101 changes: 101 additions & 0 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,104 @@ def is_out_or_ignored(root, directory):

def collect_repo(self, onerror: Callable[[str, Exception], None] = None):
return list(self._collect_repo(onerror))

def _load_file(self, path):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated from load_file above that also provides params/metrics/plots. Will refactor in successive PRs. This method serves _collect_all_from_repo below.

from dvc.dvcfile import Dvcfile
from dvc.stage.loader import SingleStageLoader, StageLoader

path = self._get_filepath(path)
dvcfile = Dvcfile(self.repo, path)
# `dvcfile.stages` is not cached
stages = dvcfile.stages # type: ignore

if isinstance(stages, SingleStageLoader):
stages_ = [stages[None]]
else:
assert isinstance(stages, StageLoader)
keys = self._get_keys(stages)
stages_ = [stages[key] for key in keys]

return (
stages_,
stages.metrics_data,
stages.plots_data,
stages.params_data,
)

def _collect_all_from_repo(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad naming I know, also duplicates logic from _collect_repo(). Again, will be refactored in successive PRs.

self, onerror: Callable[[str, Exception], None] = None
):
"""Collects all of the stages present in the DVC repo.

Args:
onerror (optional): callable that will be called with two args:
the filepath whose collection failed and the exc instance.
It can report the error to continue with the collection
(and, skip failed ones), or raise the exception to abort
the collection.
"""
from dvc.dvcfile import is_valid_filename
from dvc.fs import LocalFileSystem

scm = self.repo.scm
sep = self.fs.sep
outs: Set[str] = set()

is_local_fs = isinstance(self.fs, LocalFileSystem)

def is_ignored(path):
# apply only for the local fs
return is_local_fs and scm.is_ignored(path)

def is_dvcfile_and_not_ignored(root, file):
return is_valid_filename(file) and not is_ignored(
f"{root}{sep}{file}"
)

def is_out_or_ignored(root, directory):
dir_path = f"{root}{sep}{directory}"
# trailing slash needed to check if a directory is gitignored
return dir_path in outs or is_ignored(f"{dir_path}{sep}")

walk_iter = self.repo.dvcignore.walk(self.fs, self.repo.root_dir)
if logger.isEnabledFor(logging.TRACE): # type: ignore[attr-defined]
walk_iter = log_walk(walk_iter)

stages = []
metrics = {}
plots = {}
params = {}

for root, dirs, files in walk_iter:
dvcfile_filter = partial(is_dvcfile_and_not_ignored, root)
for file in filter(dvcfile_filter, files):
file_path = self.fs.path.join(root, file)
try:
(
new_stages,
new_metrics,
new_plots,
new_params,
) = self._load_file(file_path)
except DvcException as exc:
if onerror:
onerror(relpath(file_path), exc)
continue
raise

stages.extend(new_stages)
if new_metrics:
metrics[file_path] = new_metrics
if new_plots:
plots[file_path] = new_plots
if new_params:
params[file_path] = new_params

outs.update(
out.fspath
for stage in new_stages
for out in stage.outs
if out.protocol == "local"
)
dirs[:] = [d for d in dirs if not is_out_or_ignored(root, d)]
return stages, metrics, plots, params
2 changes: 2 additions & 0 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def validator(data):
PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]),
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
VARS_KWD: VARS_SCHEMA,
StageParams.PARAM_PARAMS: [str],
StageParams.PARAM_METRICS: [str],
}

COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)
Expand Down
6 changes: 6 additions & 0 deletions dvc/stage/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, dvcfile, data, lockfile_data=None):
self.dvcfile = dvcfile
self.data = data or {}
self.stages_data = self.data.get("stages", {})
self.metrics_data = self.data.get("metrics", [])
self.params_data = self.data.get("params", [])
self.plots_data = self.data.get("plots", {})
Comment on lines +27 to +29
Copy link
Member Author

@skshetry skshetry Dec 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole PR is standing on these 3 lines. It feels weird to use StageLoader to get other sections, but this was a quick-and-easy solution.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to rearchitect here.

self.repo = self.dvcfile.repo

lockfile_data = lockfile_data or {}
Expand Down Expand Up @@ -171,6 +174,9 @@ def __init__(self, dvcfile, stage_data, stage_text=None):
self.dvcfile = dvcfile
self.stage_data = stage_data or {}
self.stage_text = stage_text
self.metrics_data = []
self.params_data = []
self.plots_data = {}

def __getitem__(self, item):
if item:
Expand Down
29 changes: 29 additions & 0 deletions tests/func/metrics/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json
from os.path import join

import pytest

from dvc.cli import main
from dvc.utils import relpath


def test_metrics_diff_simple(tmp_dir, scm, dvc, run_copy_metrics):
Expand Down Expand Up @@ -206,3 +210,28 @@ def _gen(val):
assert result == {
"some_file.yaml": {"foo": {"old": 1, "new": 3, "diff": 2}}
}


@pytest.mark.parametrize(
"dvcfile, metrics_file",
[
("dvc.yaml", "my_metrics.yaml"),
("dir/dvc.yaml", "my_metrics.yaml"),
("dir/dvc.yaml", join("..", "my_metrics.yaml")),
],
)
def test_diff_top_level_metrics(tmp_dir, dvc, scm, dvcfile, metrics_file):
directory = (tmp_dir / dvcfile).parent
directory.mkdir(exist_ok=True)
(tmp_dir / dvcfile).dump({"metrics": [metrics_file]})

metrics_file = directory / metrics_file
metrics_file.dump({"foo": 3})
scm.add_commit([metrics_file, tmp_dir / dvcfile], message="add metrics")

metrics_file.dump({"foo": 5})
assert dvc.metrics.diff() == {
relpath(directory / metrics_file): {
"foo": {"diff": 2, "new": 5, "old": 3}
}
}
Loading