diff --git a/dvc/command/run.py b/dvc/command/run.py
index aa4f7ded5a..4b33a468ff 100644
--- a/dvc/command/run.py
+++ b/dvc/command/run.py
@@ -40,6 +40,7 @@ def run(self):
                 metrics=self.args.metrics,
                 metrics_no_cache=self.args.metrics_no_cache,
                 deps=self.args.deps,
+                params=self.args.params,
                 fname=self.args.file,
                 cwd=self.args.cwd,
                 wdir=self.args.wdir,
@@ -111,6 +112,13 @@ def add_parser(subparsers, parent_parser):
         help="Declare output file or directory "
         "(do not put into DVC cache).",
     )
+    run_parser.add_argument(
+        "-p",
+        "--params",
+        action="append",
+        default=[],
+        help="Declare parameter to use as additional dependency.",
+    )
     run_parser.add_argument(
         "-m",
         "--metrics",
diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py
index a968f01443..9d2e93e00a 100644
--- a/dvc/dependency/__init__.py
+++ b/dvc/dependency/__init__.py
@@ -8,6 +8,7 @@
 from dvc.dependency.local import DependencyLOCAL
 from dvc.dependency.s3 import DependencyS3
 from dvc.dependency.ssh import DependencySSH
+from dvc.dependency.param import DependencyPARAMS
 from dvc.output.base import OutputBase
 from dvc.remote import Remote
 from dvc.scheme import Schemes
@@ -42,31 +43,32 @@
 SCHEMA = output.SCHEMA.copy()
 del SCHEMA[OutputBase.PARAM_CACHE]
 del SCHEMA[OutputBase.PARAM_METRIC]
-SCHEMA[DependencyREPO.PARAM_REPO] = DependencyREPO.REPO_SCHEMA
+SCHEMA.update(DependencyREPO.REPO_SCHEMA)
+SCHEMA.update(DependencyPARAMS.PARAM_SCHEMA)
 
 
-def _get(stage, p, info):
-    parsed = urlparse(p)
+def _get_by_path(stage, path, info):
+    parsed = urlparse(path)
 
     if parsed.scheme == "remote":
         remote = Remote(stage.repo, name=parsed.netloc)
-        return DEP_MAP[remote.scheme](stage, p, info, remote=remote)
+        return DEP_MAP[remote.scheme](stage, path, info, remote=remote)
 
     if info and info.get(DependencyREPO.PARAM_REPO):
         repo = info.pop(DependencyREPO.PARAM_REPO)
-        return DependencyREPO(repo, stage, p, info)
+        return DependencyREPO(repo, stage, path, info)
 
     for d in DEPS:
-        if d.supported(p):
-            return d(stage, p, info)
-    return DependencyLOCAL(stage, p, info)
+        if d.supported(path):
+            return d(stage, path, info)
+    return DependencyLOCAL(stage, path, info)
 
 
 def loadd_from(stage, d_list):
     ret = []
     for d in d_list:
         p = d.pop(OutputBase.PARAM_PATH)
-        ret.append(_get(stage, p, d))
+        ret.append(_get_by_path(stage, p, d))
     return ret
 
 
@@ -74,5 +76,11 @@ def loads_from(stage, s_list, erepo=None):
     ret = []
     for s in s_list:
         info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {}
-        ret.append(_get(stage, s, info))
+        dep_obj = _get_by_path(stage, s, info)
+        ret.append(dep_obj)
     return ret
+
+
+def loads_params(stage, s_list):  # TODO: Make support for `erepo=` as well ?
+    """Create parameter dependencies from a list of input strings."""
+    return DependencyPARAMS.from_list(stage, s_list)
diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py
new file mode 100644
index 0000000000..f44139579c
--- /dev/null
+++ b/dvc/dependency/param.py
@@ -0,0 +1,124 @@
+import json
+import re
+from itertools import groupby
+
+from dvc.dependency.local import DependencyLOCAL
+from dvc.exceptions import DvcException
+
+
+class BadParamNameError(DvcException):
+    def __init__(self, param_name):
+        msg = "Parameter name '{}' is not valid".format(param_name)
+        super().__init__(msg)
+
+
+class BadParamFileError(DvcException):
+    def __init__(self, path):
+        msg = "Parameter file '{}' could not be read".format(path)
+        super().__init__(msg)
+
+
+class DependencyPARAMS(DependencyLOCAL):
+    """Dependency on named parameters read from a JSON parameters file.
+
+    Inputs look like "[PATH:]NAME[,NAME...]"; an omitted path falls back
+    to DEFAULT_PARAMS_FILE.
+    """
+
+    # SCHEMA:
+    #   params:
+    #       param_name: param_value
+    #       param_name: param_value
+    PARAM_PARAMS = "params"
+    PARAM_SCHEMA = {PARAM_PARAMS: {str: str}}
+    FILE_DELIMITER = ":"
+    PARAM_DELIMITER = ","
+    DEFAULT_PARAMS_FILE = "params.json"
+
+    REGEX_SUBNAME = r"\w+"
+    REGEX_NAME = r"{sub}(\.{sub})*".format(sub=REGEX_SUBNAME)
+    REGEX_MULTI_PARAMS = r"^{param}(,{param})*$".format(param=REGEX_NAME)
+    REGEX_COMPILED = re.compile(REGEX_MULTI_PARAMS)
+
+    def __init__(self, stage, input_str, *args, **kwargs):
+        path, param_names = self._parse_and_validate_input(input_str)
+        super().__init__(stage, path, *args, **kwargs)
+        self.param_names = sorted(param_names.split(self.PARAM_DELIMITER))
+        self.param_values = {}
+
+    def __str__(self):
+        path = super().__str__()
+        return self._reverse_parse_input(path, self.param_names)
+
+    @classmethod
+    def from_list(cls, stage, s_list):
+        """Build one dependency per distinct parameters file in `s_list`."""
+        ret = []
+        pathname_tuples = [cls._parse_and_validate_input(s) for s in s_list]
+        # groupby() only merges adjacent items, hence the sort by path first.
+        grouped_by_path = groupby(sorted(pathname_tuples), key=lambda x: x[0])
+        for path, group in grouped_by_path:
+            param_names = [g[1] for g in group]
+            regrouped_input = cls._reverse_parse_input(path, param_names)
+            ret.append(cls(stage, regrouped_input))
+        return ret
+
+    @classmethod
+    def _parse_and_validate_input(cls, input_str):
+        """Split "[PATH:]NAMES" into (path, names) and validate the names."""
+        # rpartition: only the last ":" separates the path from the names.
+        path, _, param_names = input_str.rpartition(cls.FILE_DELIMITER)
+        cls._validate_input(param_names)
+        path = path or cls.DEFAULT_PARAMS_FILE
+        return path, param_names
+
+    @classmethod
+    def _reverse_parse_input(cls, path, param_names):
+        """Join a path and a list of names back into an input string."""
+        return "{path}{delimiter}{params}".format(
+            path=path,
+            delimiter=cls.FILE_DELIMITER,
+            params=cls.PARAM_DELIMITER.join(param_names),
+        )
+
+    @classmethod
+    def _validate_input(cls, param_names):
+        """Raise BadParamNameError unless `param_names` matches the regex."""
+        if not cls.REGEX_COMPILED.match(param_names):
+            raise BadParamNameError(param_names)
+
+    def save(self):
+        """Run the base save, then record the requested values."""
+        super().save()
+        params_in_file = self._parse_file()
+        self.param_values = {k: params_in_file[k] for k in self.param_names}
+
+    def dumpd(self):
+        """Return a dict with the path and the recorded parameter values."""
+        return {
+            self.PARAM_PATH: self.def_path,
+            self.PARAM_PARAMS: self.param_values,
+        }
+
+    @property
+    def exists(self):
+        """True if the file exists and defines every requested parameter."""
+        # Check the file first: parsing a missing file would raise instead
+        # of reporting the dependency as absent.
+        if not super().exists:
+            return False
+        params_in_file = self._parse_file()
+        return all(p in params_in_file for p in self.param_names)
+
+    def _parse_file(self):
+        """Return the parsed JSON content of the file, loading it once."""
+        try:
+            return self._params_cache
+        except AttributeError:
+            path = self.path_info.fspath
+            with open(path, "r") as fp:
+                try:
+                    self._params_cache = json.load(fp)
+                except json.JSONDecodeError as exc:
+                    raise BadParamFileError(path) from exc
+        return self._params_cache
diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py
index 53edac176b..56292c4272 100644
--- a/dvc/dependency/repo.py
+++ b/dvc/dependency/repo.py
@@ -14,9 +14,11 @@ class DependencyREPO(DependencyLOCAL):
     PARAM_REV_LOCK = "rev_lock"
 
     REPO_SCHEMA = {
-        Required(PARAM_URL): str,
-        PARAM_REV: str,
-        PARAM_REV_LOCK: str,
+        PARAM_REPO: {
+            Required(PARAM_URL): str,
+            PARAM_REV: str,
+            PARAM_REV_LOCK: str,
+        }
     }
 
     def __init__(self, def_repo, stage, *args, **kwargs):
diff --git a/dvc/stage.py b/dvc/stage.py
index a3e015a6be..9d8ac9225b 100644
--- a/dvc/stage.py
+++ b/dvc/stage.py
@@ -533,9 +533,11 @@ def create(repo, accompany_outs=False, **kwargs):
         )
 
         Stage._fill_stage_outputs(stage, **kwargs)
-        stage.deps = dependency.loads_from(
+        deps = dependency.loads_from(
             stage, kwargs.get("deps", []), erepo=kwargs.get("erepo", None)
         )
+        params = dependency.loads_params(stage, kwargs.get("params", []))
+        stage.deps = deps + params
 
         stage._check_circular_dependency()
         stage._check_duplicated_arguments()
diff --git a/tests/basic_env.py b/tests/basic_env.py
index 119a599182..49353070a1 100644
--- a/tests/basic_env.py
+++ b/tests/basic_env.py
@@ -38,6 +38,10 @@ class TestDirFixture(object):
     # in tests, we replace foo with bar, so we need to make sure that when we
     # modify a file in our tests, its content length changes.
     BAR_CONTENTS = BAR + "r"
+    PARAMSDEFAULT = "params.json"
+    PARAMSDEFAULT_CONTENTS = '{"p_one": "1", "p_two": "1"}'
+    PARAMS = "par.json"
+    PARAMS_CONTENTS = '{"p_three": "3"}'
     CODE = "code.py"
     CODE_CONTENTS = (
         "import sys\nimport shutil\n"
@@ -87,6 +91,8 @@ def setUp(self):
         self._pushd(self._root_dir)
         self.create(self.FOO, self.FOO_CONTENTS)
         self.create(self.BAR, self.BAR_CONTENTS)
+        self.create(self.PARAMSDEFAULT, self.PARAMSDEFAULT_CONTENTS)
+        self.create(self.PARAMS, self.PARAMS_CONTENTS)
         self.create(self.CODE, self.CODE_CONTENTS)
         os.mkdir(self.DATA_DIR)
         os.mkdir(self.DATA_SUB_DIR)
diff --git a/tests/func/test_run.py b/tests/func/test_run.py
index 0f97672a67..8e1bc195ce 100644
--- a/tests/func/test_run.py
+++ b/tests/func/test_run.py
@@ -35,6 +35,7 @@ class TestRun(TestDvc):
     def test(self):
         cmd = "python {} {} {}".format(self.CODE, self.FOO, "out")
         deps = [self.FOO, self.CODE]
+        params = ["p_one", "p_two", "par.json:p_three"]
         outs = [os.path.join(self.dvc.root_dir, "out")]
         outs_no_cache = []
         fname = "out.dvc"
@@ -45,6 +46,7 @@ def test(self):
             cmd=cmd,
             deps=deps,
             outs=outs,
+            params=params,
             outs_no_cache=outs_no_cache,
             fname=fname,
             cwd=cwd,
@@ -53,7 +55,7 @@ def test(self):
         self.assertTrue(filecmp.cmp(self.FOO, "out", shallow=False))
         self.assertTrue(os.path.isfile(stage.path))
         self.assertEqual(stage.cmd, cmd)
-        self.assertEqual(len(stage.deps), len(deps))
+        self.assertEqual(len(stage.deps), len(deps) + 2)
         self.assertEqual(len(stage.outs), len(outs + outs_no_cache))
         self.assertEqual(stage.outs[0].fspath, outs[0])
         self.assertEqual(stage.outs[0].checksum, file_md5(self.FOO)[0])
diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py
new file mode 100644
index 0000000000..b78f7e0eba
--- /dev/null
+++ b/tests/unit/dependency/test_params.py
@@ -0,0 +1,18 @@
+import mock
+
+from dvc.dependency import DependencyPARAMS
+from dvc.stage import Stage
+from tests.basic_env import TestDvc
+
+
+class TestDependencyPARAM(TestDvc):
+    def test_from_list(self):
+        stage = Stage(self.dvc)
+        deps = DependencyPARAMS.from_list(
+            stage, ["foo", "bar,baz", "a_file:qux"]
+        )
+        assert len(deps) == 2
+        assert deps[0].def_path == "a_file"
+        assert deps[0].param_names == ["qux"]
+        assert deps[1].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE
+        assert deps[1].param_names == ["bar", "baz", "foo"]