From 2f37b1ba7e43a4f91368a959b6d6914e2f47fa2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Fri, 4 Nov 2022 16:46:00 +0545 Subject: [PATCH] wip: top level params and metrics --- dvc/repo/metrics/show.py | 23 +++++++++++++++++++++++ dvc/repo/params/show.py | 32 ++++++++++++++++++++++++++++++++ dvc/schema.py | 2 ++ 3 files changed, 57 insertions(+) diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index 69dc7c1ae0e..8237f237088 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -28,6 +28,28 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths: return result +def _collect_top_level_metrics(repo): + from dvc.dvcfile import Dvcfile + from dvc.stage import PipelineStage + + files = [] + dvcfiles = { + stage.dvcfile + for stage in repo.index.stages + if isinstance(stage, PipelineStage) + } + dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml"))) + for dvcfile in dvcfiles: + wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path)) + try: + metrics = dvcfile.load().get("metrics", []) + except Exception: # pylint: disable=broad-except + logger.debug("", exc_info=True) + continue + files.extend(repo.dvcfs.path.join(wdir, file) for file in metrics) + return files + + def _collect_metrics(repo, targets, revision, recursive): metrics, fs_paths = collect( repo, @@ -94,6 +116,7 @@ def _read_metrics(repo, metrics, rev, onerror=None): def _gather_metrics(repo, targets, rev, recursive, onerror=None): metrics = _collect_metrics(repo, targets, rev, recursive) + metrics.extend(_collect_top_level_metrics(repo)) return _read_metrics(repo, metrics, rev, onerror=onerror) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 684eb4037d2..cf031a6729d 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -35,6 +35,27 @@ def _is_params(dep: "Output"): return isinstance(dep, ParamsDependency) +def _collect_top_level_params(repo): + from dvc.dvcfile import Dvcfile + + files = [] + dvcfiles = { + stage.dvcfile + for stage in repo.index.stages + if isinstance(stage, PipelineStage) + } + dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml"))) + for dvcfile in dvcfiles: + wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path)) + try: + params = dvcfile.load().get("params", []) + except Exception: # pylint: disable=broad-except + logger.debug("", exc_info=True) + continue + files.extend(repo.dvcfs.path.join(wdir, file) for file in params) + return files + + def _collect_configs( repo: "Repo", rev, targets=None, deps=False, stages=None ) -> Tuple[List["Output"], List[str]]: @@ -93,7 +114,17 @@ def _read_params( else: fs_paths += [param.fs_path for param in params] + relpath = "" + if repo.root_dir != repo.fs.path.getcwd(): + relpath = repo.fs.path.relpath(repo.root_dir, repo.fs.path.getcwd()) + for fs_path in fs_paths: + rel_param_path = os.path.join(relpath, *repo.fs.path.parts(fs_path)) + if not repo.fs.isfile(fs_path): + if repo.fs.isfile(rel_param_path): + fs_path = rel_param_path + else: + continue from_path = _read_fs_path(repo.fs, fs_path, onerror=onerror) if from_path: name = os.sep.join(repo.fs.path.relparts(fs_path)) @@ -176,6 +207,7 @@ def _gather_params( param_outs, params_fs_paths = _collect_configs( repo, rev, targets=targets, deps=deps, stages=stages ) + params_fs_paths.extend(_collect_top_level_params(repo=repo)) params = _read_params( repo, params=param_outs, diff --git a/dvc/schema.py b/dvc/schema.py index b48f61aa095..02c16c27605 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -123,6 +123,8 @@ def validator(data): PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]), STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, VARS_KWD: VARS_SCHEMA, + StageParams.PARAM_PARAMS: [str], + StageParams.PARAM_METRICS: [str], } COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)